-- TODO:
-- - Polymorphically recursive data types

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE ViewPatterns #-}

-- This warning is over-eager in TH quotes when the variables that the pattern
-- binds are spliced instead of mentioned directly. See
-- https://gitlab.haskell.org/ghc/ghc/-/issues/22057 . Fixed in GHC 9.6.1.
-- {-# OPTIONS -Wno-unused-pattern-binds #-}

module Language.Haskell.ReverseAD.TH (
  -- * Reverse AD
  reverseAD,
  reverseAD',
  -- * Structure descriptions
  Structure,
  structureFromTypeable,
  structureFromType,
  -- * Special methods
  (|*|),

  -- * Debug
  evlog,
) where

import Control.Concurrent
import Control.Monad (when)
import Control.Parallel (par)
import Data.Bifunctor (first, second)
import Data.Char (isAlphaNum)
import Data.List (zip4, intercalate)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.IORef
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Typeable
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector.Storable.Mutable as MVS
import Data.Word (Word8, Word16, Word32, Word64)
import GHC.Exts (Multiplicity(..))
import Language.Haskell.TH
import Language.Haskell.TH.Syntax as TH hiding (lift)
import System.Clock
import System.IO.Unsafe

import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS8
import Control.DeepSeq
import Control.Exception (evaluate)

import Control.Monad.IO.Class
-- import Debug.Trace
import System.IO

import Control.Concurrent.ThreadPool
import Data.Vector.Storable.Mutable.CAS
import Language.Haskell.ReverseAD.TH.Orphans ()
import Language.Haskell.ReverseAD.TH.Source as Source
import Language.Haskell.ReverseAD.TH.Translate


-- | Whether to enable debug prints in the differentiation code. This is quite
-- spammy and should be turned off when not actually debugging.
--
-- It is consulted by 'when' guards around the 'stderr' traces of the reverse
-- pass.
kDEBUG :: Bool
kDEBUG = False


-- === The program transformation ===
--
-- type DN = (Double, NID, Contrib)
-- type NID = (JobID, Int)
-- data Contrib = Contrib [(NID, Contrib, Double)]
--
-- Dt[Double] = DN
-- Dt[()] = ()
-- Dt[(a, b)] = (Dt[a], Dt[b])
-- Dt[a -> b] = Dt[a] -> FwdM Dt[b]
-- Dt[Int] = Int
-- Dt[T a b c] = T Dt[a] Dt[b] Dt[c]      -- data types, generalises (,)
--
-- Dt[eps] = eps
-- Dt[Γ, x : a] = Dt[Γ], x : Dt[a]
--
-- FwdM is a monad with:
--   gen :: FwdM NID
--   par :: FwdM a -> FwdM b -> FwdM (a, b)
--
-- Γ |- t : a
-- ~> Dt[Γ] |- D[t] : FwdM Dt[a]
-- D[r] = do i <- gen; pure (r, i, Contrib [])
-- D[x] = pure x
-- D[()] = pure ()
-- D[(s, t)] = do x <- D[s]
--                y <- D[t]
--                pure (x, y)
-- D[C s t u] = do x <- D[s]
--                 y <- D[t]
--                 z <- D[u]
--                 pure (C x y z)
-- D[case s of C1 x1 x2 -> t1 ; ... ; Cn x1 x2 -> tn] =
--     do x <- D[s]
--        case x of
--          C1 x1 x2 -> D[t1]
--          ...
--          Cn x1 x2 -> D[tn]
-- D[s t] = do f <- D[s]
--             a <- D[t]
--             f a
-- D[\x -> t] = pure (\x -> D[t])
-- D[let x = s in t] = do x <- D[s]
--                        D[t]
-- D[op t1..tn] =
--   do (x1, i1, cb1) <- D[t1]
--      (x2, i2, cb2) <- D[t2]
--      ...
--      (xn, in, cbn) <- D[tn]
--      i <- gen
--      pure (op x1..xn, i, Contrib [(i1, cb1, dop_1 x1..xn), ..., (in, cbn, dop_n x1..xn)])


-- ----------------------------------------------------------------------
-- Additional API
-- ----------------------------------------------------------------------

-- | Parallel (strict) pair construction.
--
-- The definition of @x |*| y@ is @x \`'par'\` y \`'par'\` (x, y)@: @x@ and @y@
-- are evaluated in parallel. This also means that this pair constructor is, in
-- a certain sense, strict.
--
-- In differentiation using 'reverseAD', this function is specially interpreted
-- so that not only the forward pass, but also the reverse gradient pass runs
-- in parallel.
(|*|) :: a -> b -> (a, b)
-- 'par' is infixr 0, so the infix chain above nests to the right exactly as
-- written here in prefix form.
x |*| y = par x (par y (x, y))


-- ------------------------------------------------------------
-- The monad for the target program
-- ------------------------------------------------------------

-- | The ID of a parallel job, >=0. The implicit main job has ID 0, parallel
-- jobs start from 1.
newtype JobID = JobID Int
  deriving Show

-- | The history preceding a job: either it is a source, or it is the join of
-- two branches forked from an earlier job.
data BeforeJob
  = Start
      -- ^ Source of (this part of the) graph
  | Fork !JobDescr !JobDescr !JobDescr
      -- ^ a b c: a forked into b and c, which joined into the current job
  deriving Show

-- | A job together with its history ('BeforeJob'), its ID, and the number of
-- node IDs it has handed out so far.
data JobDescr = JobDescr
    !BeforeJob
    {-# UNPACK #-} !JobID   -- ^ The ID of this job
    {-# UNPACK #-} !Int     -- ^ Number of IDs generated in this thread (i.e. last ID + 1)
  deriving Show

-- | Isomorphic to: @ReaderT (IORef JobID) (StateT JobDescr (ContT () IO)) a@
--
-- The 'IORef' contains the next job ID to generate. The input 'JobDescr' is
-- the task so far; the output 'JobDescr' describes the terminal job of this
-- task (including its history).
newtype FwdM a =
  FwdM (IORef JobID -> JobDescr -> (JobDescr -> a -> IO ()) -> IO ())

instance Functor FwdM where
  -- Run the wrapped computation unchanged and apply @f@ to its result just
  -- before handing it to the continuation. The job descriptors and the
  -- intermediate result are forced on the way through.
  fmap f (FwdM run) = FwdM step
    where
      step jr !jd k = run jr jd (\ !jd1 !x -> k jd1 (f x))

instance Applicative FwdM where
  -- 'pure' forces its argument and passes it on without touching the job
  -- descriptor or the job-ID counter.
  pure !x = FwdM $ \_ !jd k -> k jd x

  -- Sequential application: run the function computation, then the argument
  -- computation starting from the job descriptor the first one ended in.
  FwdM runFun <*> FwdM runArg = FwdM $ \jr !jd k ->
    let withFun !jd1 !fun = runArg jr jd1 (\ !jd2 !x -> k jd2 (fun x))
    in runFun jr jd withFun

instance Monad FwdM where
  -- Bind threads the job-ID reference through unchanged and continues the
  -- second computation from the job descriptor the first one finished in.
  FwdM run >>= g = FwdM $ \jr !jd k ->
    run jr jd $ \ !jd1 !x ->
      -- Matching on a newtype constructor does not force 'g x'.
      case g x of
        FwdM next -> next jr jd1 k

instance MonadIO FwdM where
  -- Run the IO action inline and continue with the job descriptor unchanged.
  liftIO m = FwdM $ \_ !jd k -> do
    x <- m
    k jd x

-- | 'pure', specialised to 'FwdM'.
mpure :: a -> FwdM a
mpure = pure

-- | Run a forward computation to completion, starting on the main job (ID 0),
-- and block until its final continuation has fired.
--
-- Returns:
-- - the terminal job, i.e. the job with which the computation ended
-- - the maximal job ID generated plus 1
-- - the result value
--
-- Uses 'unsafePerformIO'; NOINLINE keeps GHC from duplicating or floating the
-- IO (and its fresh IORef/MVar) across call sites.
{-# NOINLINE runFwdM #-}
runFwdM :: FwdM a -> (JobDescr, JobID, a)
runFwdM (FwdM run) = unsafePerformIO $ do
  evlog "fwdm start"
  -- Job 0 is the implicit main job, so the next fresh job ID is 1.
  nextJobRef <- newIORef (JobID 1)
  result <- newEmptyMVar
  let mainJob = JobDescr Start (JobID 0) 0
  run nextJobRef mainJob (\jd y -> putMVar result (jd, y))
  -- The final continuation may fire on another thread; wait for it here.
  (terminal, y) <- takeMVar result
  nextJob <- readIORef nextJobRef
  evlog "fwdm end"
  pure (terminal, nextJob, y)

-- | 'a' and 'b' are computed in new, separate jobs.
--
-- Both branches are submitted to the global thread pool, each starting a fresh
-- job. Once both have delivered their results, the continuation runs in a
-- third fresh job whose history records the fork (parent, left, right).
fwdmPar :: FwdM a -> FwdM b -> FwdM (a, b)
fwdmPar (FwdM runLeft) (FwdM runRight) = FwdM $ \jr !jd0 k -> do
  -- Reserve three consecutive job IDs in one atomic step: left branch, right
  -- branch, and the job that continues after the join.
  (jiLeft, jiRight, jiJoin) <- atomicModifyIORef' jr $ \(JobID j) ->
    (JobID (j + 3), (JobID j, JobID (j + 1), JobID (j + 2)))
  debug "! Starting fork"
  leftDone <- newEmptyMVar
  rightDone <- newEmptyMVar
  let freshJob ji = JobDescr Start ji 0
  submitJob globalThreadPool $
    runLeft jr (freshJob jiLeft) $ \jd1 x -> do
      debug "! join left"
      putMVar leftDone (jd1, x)
  submitJob globalThreadPool $
    runRight jr (freshJob jiRight) $ \jd2 y -> do
      debug "! join right"
      putMVar rightDone (jd2, y)
  -- Wait for both branches on a separate thread instead of blocking the
  -- current one, then schedule the continuation on the pool.
  _ <- forkIO $ do
    (jd1, x) <- takeMVar leftDone
    (jd2, y) <- takeMVar rightDone
    debug "! Joined"
    submitJob globalThreadPool (k (JobDescr (Fork jd0 jd1 jd2) jiJoin 0) (x, y))
  pure ()

-- | The tag on a node in the Contrib graph. Consists of a job ID and the ID of
-- the node within that job.
data NID = NID
  {-# UNPACK #-} !JobID
  {-# UNPACK #-} !Int
  deriving Show

-- | Generate a fresh node ID in the current job: the job's ID paired with its
-- running counter, after which the counter is incremented.
fwdmGenId :: FwdM NID
fwdmGenId = FwdM bump
  where
    bump _ (JobDescr prev ji i) k =
      let next = JobDescr prev ji (i + 1)
      in k next (NID ji i)

-- | Like 'fwdmGenId', but returns only the bare counter. It is only allowed on
-- the main job (ID 0) and raises an error on any other job.
fwdmGenIdInterleave :: FwdM Int
fwdmGenIdInterleave = FwdM $ \_ (JobDescr prev ji@(JobID jiInt) i) k ->
  case jiInt of
    0 -> k (JobDescr prev ji (i + 1)) i
    _ -> error "fwdmGenIdInterleave: not on main thread"


-- ----------------------------------------------------------------------
-- The State type
-- ----------------------------------------------------------------------

-- | For each job (indexed by its 'JobID' number), the unzipped vector of
-- (Contrib, adjoint) pairs. The inner "vector" is Nothing while the items are
-- still being allocated in 'allocBS'.
type MaybeBuildState =
  MV.IOVector (Maybe (MV.IOVector Contrib, MVS.IOVector Double))

-- | 'MaybeBuildState' after every job's arrays have been allocated.
type BuildState =
  MV.IOVector (MV.IOVector Contrib, MVS.IOVector Double)

-- | Contrib edge: an edge in the contribution graph.
data CEdge = CEdge
  {-# UNPACK #-} !NID     -- ^ node receiving the contribution
  !Contrib                -- ^ that node's own outgoing contributions
  {-# UNPACK #-} !Double  -- ^ factor scaling the adjoint sent along this edge

-- | The outgoing edges of a node. 'C0', 'C1' and 'C2' are unpacked special
-- cases for zero, one and two edges; 'C3' holds an arbitrary list.
data Contrib
  = C0
  | C1 {-# UNPACK #-} !CEdge
  | C2 {-# UNPACK #-} !CEdge {-# UNPACK #-} !CEdge
  | C3 ![CEdge]

-- | Debug rendering of a 'Contrib' as a comma-separated list of its edges,
-- e.g. @Contrib [(NID (JobID 0) 3, <>, 1.0)]@.
debugContrib :: Contrib -> String
debugContrib cb = "Contrib [" ++ intercalate "," (map debugCEdge (edges cb)) ++ "]"
  where
    edges C0 = []
    edges (C1 e) = [e]
    edges (C2 e1 e2) = [e1, e2]
    edges (C3 l) = l

-- | Debug rendering of a single edge. The target's 'Contrib' is elided as
-- @<>@, since printing it would recurse into the rest of the graph.
debugCEdge :: CEdge -> String
debugCEdge (CEdge nid _ d) = concat ["(", show nid, ", <>, ", show d, ")"]

-- | Allocate the (Contrib, adjoint) arrays of every job reachable from the
-- given terminal job. Each job's arrays are sized by the number of node IDs it
-- generated; contribs start as 'C0' and adjoints as 0. It is an error for a
-- job's slot to be filled already.
--
-- TODO: This function is sequential in the total number of jobs started in the
-- forward pass. Those jobs were started in parallel, so in the worst case we
-- sequentialise a bit here. It's not _so_ bad, though.
allocBS :: JobDescr -> MaybeBuildState -> IO ()
allocBS topjobdescr threads = visit topjobdescr
  where
    visit :: JobDescr -> IO ()
    visit (JobDescr prev (JobID ji) n) = do
      -- hPutStrLn stderr $ "allocBS: ji=" ++ show ji
      slot <- MV.read threads ji
      case slot of
        Just _ -> error ("allocBS: already allocated? (" ++ show ji ++ ")")
        Nothing -> do
          cbarr <- MV.replicate n C0
          adjarr <- MVS.replicate n 0.0
          MV.write threads ji (Just (cbarr, adjarr))
      case prev of
        Start -> pure ()
        Fork parent jd1 jd2 -> visit jd1 >> visit jd2 >> visit parent

-- | Run the reverse pass over the job graph ending in @terminalJob@, and block
-- until it has completed.
--
-- Within a job, nodes are visited from the highest index down to 0. Each
-- node's adjoint is pushed along its 'Contrib' edges with 'addContrib'. After
-- a 'Fork', the two branch jobs are resolved in parallel on the global thread
-- pool, and the parent job is resolved once both have finished.
resolve :: JobDescr -> BuildState -> IO ()
resolve terminalJob threads = do
  cell <- newEmptyMVar
  resolveTask terminalJob (putMVar cell ())
  takeMVar cell
  where
    -- Resolve one job, then whatever preceded it. @k@ runs once the whole
    -- history of this job has been resolved.
    resolveTask :: JobDescr -> IO () -> IO ()
    resolveTask (JobDescr prev ji@(JobID jiInt) i) k = do
      when kDEBUG $
        hPutStrLn stderr $ "Enter job " ++ show ji ++ " from i=" ++ show (i - 1)
      (cbarr, adjarr) <- MV.read threads jiInt
      -- @i@ is the number of IDs this job generated, so its last node is i - 1.
      resolveJob ji (i - 1) cbarr adjarr
      case prev of
        Start -> k
        Fork jd0 jd1 jd2 -> do
          let jidOf :: JobDescr -> JobID
              jidOf (JobDescr _ n _) = n
          when kDEBUG $
            hPutStrLn stderr $
              "Forking jobs (" ++ show (jidOf jd1) ++ ") and ("
                ++ show (jidOf jd2) ++ "); parent ("
                ++ show (jidOf jd0) ++ ")"
          cell1 <- newEmptyMVar
          cell2 <- newEmptyMVar
          submitJob globalThreadPool (resolveTask jd1 (putMVar cell1 ()))
          submitJob globalThreadPool (resolveTask jd2 (putMVar cell2 ()))
          -- Wait for both branches on a separate thread rather than blocking
          -- the current one; the parent job is resolved only after both.
          _ <- forkIO $ do
            takeMVar cell1
            takeMVar cell2
            submitJob globalThreadPool (resolveTask jd0 k)
          return ()

    -- Process the nodes of one job from index @i@ down to 0. The guard is
    -- total: any negative start index (an empty job) simply does nothing.
    resolveJob :: JobID -> Int -> MV.IOVector Contrib -> MVS.IOVector Double -> IO ()
    resolveJob jid i cbarr adjarr
      | i < 0 = return ()
      | otherwise = do
          cb <- MV.read cbarr i
          adj <- MVS.read adjarr i
          when kDEBUG $
            hPutStrLn stderr $
              "Apply (" ++ show (NID jid i) ++ ") [adj=" ++ show adj ++ "]: "
                ++ debugContrib cb
          apply cb adj
          resolveJob jid (i - 1) cbarr adjarr

    -- Push adjoint @a@ along every edge of a node's Contrib.
    apply :: Contrib -> Double -> IO ()
    apply C0 _ = return ()
    apply (C1 e) a = applyEdge a e
    apply (C2 e1 e2) a = applyEdge a e1 >> applyEdge a e2
    apply (C3 l) a = mapM_ (applyEdge a) l

    -- The target node receives the adjoint scaled by the edge's factor.
    applyEdge :: Double -> CEdge -> IO ()
    applyEdge a (CEdge nid cb d) = addContrib nid cb (a * d) threads

-- | Record a contribution to node @(ji, i)@. The node's 'Contrib' is stored in
-- its job's contrib array, and @adj@ is added to the node's adjoint with a
-- compare-and-swap retry loop, so that concurrent callers do not lose updates.
--
-- This is the function called from the backpropagator built by deinterleave,
-- as well as from 'resolve'.
addContrib :: NID -> Contrib -> Double -> BuildState -> IO ()
addContrib (NID (JobID ji) i) cb adj threads = do
  (cbarr, adjarr) <- MV.read threads ji
  MV.write cbarr i cb
  -- NOTE(review): assumes casIOVectorDouble returns (True, _) on success and
  -- (False, current value) on failure; we retry with the observed value.
  let casLoop expected = do
        (ok, seen) <- casIOVectorDouble adjarr i expected (expected + adj)
        if ok then return () else casLoop seen
  orig <- MVS.read adjarr i
  casLoop orig
  -- hPutStrLn stderr $ "aC: [" ++ show ji ++ "] " ++ show i ++ ": " ++ show orig ++ " + " ++ show adj ++ " = " ++ show (orig + adj)

-- | Run the dual (backward) pass and return the adjoints of the initial job:
-- job ID 0.
--
-- The per-job buffers are allocated through 'allocBS', the backpropagator is
-- applied to @adj@, 'resolve' is run on the final job, and finally the
-- adjoint array of job 0 is frozen into the result vector.
--
-- NOTE(review): the whole body runs under 'unsafePerformIO'; the NOINLINE
-- pragma keeps GHC from inlining (and thereby duplicating) that computation.
{-# NOINLINE dualpass #-}
dualpass :: JobDescr -> JobID -> (d -> BuildState -> IO ()) -> d -> VS.Vector Double
dualpass finaljob (JobID numjobs) backprop adj = unsafePerformIO $ do
  -- hPutStrLn stderr $ "\n-------------------- ENTERING DUALPASS --------------------"
  -- hPutStrLn stderr $ "dualpass: numjobs=" ++ show numjobs
  evlog "dual start"
  -- t1 <- getTime Monotonic
  -- hPutStrLn stderr $ "dual(numjobs=" ++ show numjobs ++ ")"

  -- 'allocBS' is expected to fill every slot of this Maybe-array; an empty
  -- slot afterwards is an internal error.
  slots <- MV.replicate numjobs Nothing
  allocBS finaljob slots
  threads <- MV.generateM numjobs $ \jobIdx -> do
    slot <- MV.read slots jobIdx
    case slot of
      Just buffers -> return buffers
      Nothing -> error $ "Thread array not initialised: " ++ show jobIdx
  evlog "dual allocated"
  -- t2 <- getTime Monotonic

  backprop adj threads
  evlog "dual bped"
  -- t3 <- getTime Monotonic

  resolve finaljob threads
  evlog "dual resolved"
  -- t4 <- getTime Monotonic

  -- hPutStrLn stderr $ "\n//////////////////// FINISHED DUALPASS ////////////////////"
  (_, adjarr0) <- MV.read threads 0
  res <- VS.freeze adjarr0
  evlog "dual frozen"
  -- t5 <- getTime Monotonic
  -- let tms = [t1, t2, t3, t4, t5]
  -- hPutStrLn stderr $ "dual: " ++ intercalate " / " (zipWith (\t t' -> show (fromIntegral (toNanoSecs (diffTimeSpec t t')) / 1e6 :: Double) ++ "ms") tms (tail tms))
  return res


-- ------------------------------------------------------------
-- Structure descriptions of types
-- ------------------------------------------------------------

-- | The structure of a type, as used by the AD transformation. Use
-- 'structureFromTypeable' or 'structureFromType' to construct a 'Structure'.
--
-- It pairs the monomorphic summary of the type itself with the collection of
-- data types reachable from it.
data Structure = Structure MonoType DataTypes

deriving instance Show Structure

-- | Deep evaluation of a 'Structure'. The 'DataTypes' map is forced by fully
-- evaluating its 'show' output, which traverses every key and value
-- (NOTE(review): presumably because the TH 'Name's inside it have no
-- 'NFData' instance available here — confirm).
instance NFData Structure where
  rnf (Structure mono dtypes) = mono `deepseq` rnf (show dtypes)

-- | Analyse the 'Type' and give a 'Structure' that describes the type for use
-- in 'reverseAD''.
--
-- Fails (in 'Q') if the type is polymorphic or uses unsupported type forms.
structureFromType :: Q Type -> Q Structure
structureFromType qty = do
  ty <- qty
  uncurry Structure <$> exploreRecursiveType ty

-- | A 'TypeRep' (which we can obtain from the 'Typeable' constraint) can be
-- used to construct a 'Type' for the type @a@, on which we can then call
-- 'structureFromType'.
structureFromTypeable :: Typeable a => Proxy a -> Q Structure
structureFromTypeable proxy = structureFromType (typeRepToType (typeRep proxy))
  where
    -- Taken from th-utilities-0.2.4.3 by Michael Sloan.
    -- Rebuilds the type as a global type-constructor name (package, module,
    -- name taken from the 'TyCon') applied left-to-right to its arguments.
    typeRepToType :: TypeRep -> Q Type
    typeRepToType tr =
      let (con, args) = splitTyConApp tr
          conName =
            Name (OccName (tyConName con))
                 (NameG TcClsName
                        (PkgName (tyConPackage con))
                        (ModName (tyConModule con)))
      in foldl AppT (ConT conName) <$> traverse typeRepToType args

-- | Index for 'SimpleType' (used as a promoted kind): 'Poly' types may
-- contain type variables, 'Mono' types cannot.
data Morphic = Poly | Mono

deriving instance Show Morphic

-- | Simplified description of a Haskell type as seen by the AD
-- transformation, indexed by whether it may contain type variables.
data SimpleType (mf :: Morphic) where
  -- | A discrete (non-active) type; see 'discreteTypeNames'.
  DiscreteST :: SimpleType mf
  -- | The active scalar type ('Double').
  ScalarST :: SimpleType mf
  -- | A type variable; only representable in 'Poly' types.
  VarST :: Name -> SimpleType 'Poly
  -- | A type constructor applied to (possibly zero) type arguments.
  ConST :: Name -> [SimpleType mf] -> SimpleType mf

deriving instance Show (SimpleType mf)
deriving instance Eq (SimpleType mf)
deriving instance Ord (SimpleType mf)
deriving instance Lift (SimpleType mf)

-- | A 'SimpleType' that may still contain type variables.
type PolyType = SimpleType 'Poly
-- | A 'SimpleType' without type variables.
type MonoType = SimpleType 'Mono

-- | Deep evaluation of a 'SimpleType'. The 'Name's it contains are only
-- forced to weak head normal form.
instance NFData (SimpleType mf) where
  rnf = \case
    DiscreteST -> ()
    ScalarST -> ()
    VarST n -> n `seq` ()
    ConST n ts -> n `seq` rnf ts

-- | Type constructors that 'summariseType' classifies as discrete.
discreteTypeNames :: [Name]
discreteTypeNames = signedInts ++ unsignedInts ++ [''Char, ''Bool]
  where
    signedInts = [''Int, ''Int8, ''Int16, ''Int32, ''Int64]
    unsignedInts = [''Word, ''Word8, ''Word16, ''Word32, ''Word64]

-- | Convert a Template Haskell 'Type' into a 'PolyType'.
--
-- Recognised forms:
--
-- * type constructors: names in 'discreteTypeNames' become 'DiscreteST',
--   'Double' becomes 'ScalarST', 'Float' is rejected, anything else becomes a
--   nullary 'ConST';
-- * type variables, parenthesised types and kind-annotated types
--   (@t :: k@, where the annotation is dropped);
-- * tuple and list type constructors;
-- * applications of the above (only a type constructor may be applied).
--
-- Everything else fails in @m@.
summariseType :: MonadFail m => Type -> m PolyType
summariseType = \case
  ConT n
    | n `elem` discreteTypeNames -> return DiscreteST
    | n == ''Double -> return ScalarST
    | n == ''Float -> fail "Only Double is an active type for now (Float isn't)"
    | otherwise -> return $ ConST n []
  VarT n -> return $ VarST n
  ParensT t -> summariseType t
  -- A kind annotation does not change the type itself.
  SigT t _ -> summariseType t
  TupleT k -> return $ ConST (tupleTypeName k) []
  ListT -> return $ ConST ''[] []
  t@AppT{} -> do
    let (hd, args) = collectApps t
    hd' <- summariseType hd
    args' <- mapM summariseType args
    smartAppsST hd' args'
  t -> fail $ "Unsupported type: " ++ pprint t
  where
    -- Apply an already-summarised head to summarised arguments. Only type
    -- constructors accept arguments; these are appended to any it already has.
    smartAppsST :: MonadFail m => PolyType -> [PolyType] -> m PolyType
    smartAppsST DiscreteST _ = fail "Discrete type does not take type parameters"
    smartAppsST ScalarST _ = fail "'Double' does not take type parameters"
    -- A type variable in head position (@f a@) is higher-kinded, not
    -- higher-rank.
    smartAppsST (VarST n) _ =
      fail $ "Higher-kinded type variable not supported in reverse AD: " ++ show n
    smartAppsST (ConST n as) bs = return $ ConST n (as ++ bs)

-- | Given an inlining function that returns the value of a type /variable/,
-- monomorphise the type: every 'VarST' is replaced by the function's result,
-- all other nodes are rebuilt with the same shape.
toMonoType :: forall f. Applicative f => (Name -> f MonoType) -> PolyType -> f MonoType
toMonoType inlineVar = go
  where
    -- The signature is needed: matching on the GADT constructor 'VarST'
    -- refines the type index, which GHC cannot infer on its own.
    go :: PolyType -> f MonoType
    go DiscreteST = pure DiscreteST
    go ScalarST = pure ScalarST
    go (VarST n) = inlineVar n
    go (ConST n ts) = ConST n <$> traverse go ts

-- | Map from (type name, type arguments) to [(constructor name, fields)]:
-- for every explored monomorphic instantiation of a data type, its
-- constructors together with their monomorphised field types.
type DataTypes =
  Map (Name, [MonoType])    -- type constructor and its argument instantiation
      [(Name, [MonoType])]  -- constructors with their field types

-- | Given:
-- - The stack of types in the current exploration branch (initialise with
--   'mempty'), giving for each type name its argument instantiation
-- - The current type synonym expansion stack
-- - The monotype to explore
-- returns the transitive closure of datatypes included in the input monotype.
--
-- A type that is already on the stack with the same arguments (regular
-- recursion) is not expanded again. Exploration fails when a single branch
-- nests 20 data/newtype expansions, or 20 type synonym expansions.
--
-- The data type depth is tracked with an explicit counter rather than with the
-- size of the stack map. The map is keyed on the type name, so polymorphic
-- recursion (e.g. @data Nest a = Nil | Cons a (Nest [a])@) re-enters the same
-- name with ever larger arguments and only overwrites its entry. A limit on
-- the map size therefore never triggered for such types, and exploration
-- recursed forever.
exploreType :: Map Name [MonoType] -> Set Name -> MonoType -> Q DataTypes
exploreType initStack = go (Map.size initStack) initStack
  where
    -- Maximum nesting (per branch) of data type expansions, and separately of
    -- type synonym expansions.
    maxDepth :: Int
    maxDepth = 20

    -- Worker; the first argument is the number of data/newtype declarations
    -- expanded on the current branch. Starting from the size of the given
    -- stack keeps it at least as large as the old size-based measure.
    go :: Int -> Map Name [MonoType] -> Set Name -> MonoType -> Q DataTypes
    go depth _ tysynstack
      | depth >= maxDepth = \_ ->
          fail $ "Very deep data type hierarchy (depth " ++ show maxDepth
                 ++ "); polymorphic recursion?"
      | Set.size tysynstack >= maxDepth = \_ ->
          fail $ "Very deep type synonym hierarchy (depth " ++ show maxDepth
                 ++ "); self-recursive type synonym?"
    go depth stack tysynstack = \case
      DiscreteST -> return mempty
      ScalarST -> return mempty
      ConST tyname argtys
        | Just prevargtys <- Map.lookup tyname stack
        , argtys == prevargtys -> return mempty
            -- Regular recursion: this exact instantiation is already being
            -- explored further up the branch, so don't expand it again. A
            -- different instantiation of the same name is explored normally
            -- (it may be legitimate nesting such as Maybe (Maybe Int));
            -- polymorphic recursion is cut off by the depth limit instead.
        | otherwise -> do
            typedecl <- reify tyname >>= \case
              TyConI decl -> return decl
              info -> fail $ "Name " ++ show tyname ++ " is not a lifted type name: "
                               ++ show info
            -- Monomorphise the fields of one constructor with the current
            -- argument instantiation and explore the field types one level
            -- deeper, with this type pushed on the stack.
            let analyseConstructor :: [Name] -> Con -> Q ((Name, [MonoType]), DataTypes)
                analyseConstructor tyvars constr
                  | length tyvars == length argtys = do
                      (conname, fieldtys) <- case constr of
                        NormalC conname fieldtys ->
                          return (conname, map (\(_, ty) -> ty) fieldtys)
                        RecC conname fieldtys ->
                          return (conname, map (\(_, _, ty) -> ty) fieldtys)
                        InfixC (_, ty1) conname (_, ty2) ->
                          return (conname, [ty1, ty2])
                        _ -> fail $ "Unsupported constructor format on data: " ++ show constr
                      fieldtys' <- mapM summariseType fieldtys
                      fieldtys'' <- mapM (monomorphiseField tyname tyvars argtys) fieldtys'
                      let stack' = Map.insert tyname argtys stack
                      typesets <- mapM (go (depth + 1) stack' mempty) fieldtys''
                      return ((conname, fieldtys''), mapUnionsWithKey mergeIfEqual typesets)
                  | otherwise = fail $ "Type not fully applied: " ++ show tyname
            case typedecl of
              NewtypeD [] _ tyvars _ constr _ -> do
                (condescr, typeset) <- analyseConstructor (map tvbName tyvars) constr
                return (Map.insertWithKey mergeIfEqual (tyname, argtys) [condescr] typeset)
              DataD [] _ tyvars _ constrs _ -> do
                (condescrs, typesets) <-
                  unzip <$> mapM (analyseConstructor (map tvbName tyvars)) constrs
                let typeset = mapUnionsWithKey mergeIfEqual typesets
                return (Map.insertWithKey mergeIfEqual (tyname, argtys) condescrs typeset)
              TySynD _ tyvars rhs -> do
                -- Synonyms are expanded in place: they count towards the
                -- synonym stack, not towards the data type depth.
                srhs <- summariseType rhs
                mrhs <- monomorphiseField tyname (map tvbName tyvars) argtys srhs
                go depth stack (Set.insert tyname tysynstack) mrhs
              _ -> fail $ "Type not supported: " ++ show tyname ++ " (not simple \
                          \newtype or data)"

    -- Substitute the type's parameters by the given argument instantiation.
    monomorphiseField :: Name -> [Name] -> [MonoType] -> PolyType -> Q MonoType
    monomorphiseField tyname typarams argtys =
      toMonoType $ \n -> case lookup n (zip typarams argtys) of
        Just mt -> return mt
        Nothing -> fail $ "Type variable out of scope in definition of \
                          \data type " ++ show tyname ++ ": " ++ show n

    -- Two explorations of the same instantiation must agree.
    mergeIfEqual key v1 v2
      | v1 == v2 = v1
      | otherwise = error $ "Conflicting explorations of the same data type!\n\
                            \Key: " ++ show key ++ "\n\
                            \Val 1: " ++ show v1 ++ "\n\
                            \Val 2: " ++ show v2

-- | Summarise a closed Template Haskell type into a monomorphic simple type
-- and explore the data types it mentions (via 'exploreType', starting from an
-- empty state).
--
-- Fails in 'Q' if the type still contains a type variable: reverse AD input
-- and output types must be fully monomorphic.
exploreRecursiveType :: Type -> Q (MonoType, DataTypes)
exploreRecursiveType tau = do
  monoTy <- summariseType tau >>= toMonoType rejectTyVar
  reachable <- exploreType mempty mempty monoTy
  pure (monoTy, reachable)
  where
    -- Any type variable left after summarising is a user error.
    rejectTyVar n =
      fail $ "Reverse AD input or output type is polymorphic \
             \(contains type variable " ++ show n ++ ")"


-- ----------------------------------------------------------------------
-- Top-level interface to reverse AD
-- ----------------------------------------------------------------------

-- | Differentiate a function in reverse mode at compile time.
--
-- Use as follows:
--
-- > > :t $$(reverseAD @_ @Double [|| \(x, y) -> x * y ||])
-- > (Double, Double) -> (Double, Double -> (Double, Double))
--
-- The generated function returns the primal output together with a function
-- from an output adjoint (of type @b@) to the corresponding gradient (of type
-- @a@).
--
-- 'Double' is the active scalar type: 'Double' values are differentiated.
-- 'Float' is currently unsupported.
--
-- Note that due to the GHC stage restriction, any data types used in @a@ and
-- @b@ must be defined in a separate module that is then imported into the
-- module where you use 'reverseAD'. If you get a GHC panic about @$tcFoo@ not
-- being in scope (where @Foo@ is your data type), this means that you violated
-- this stage restriction. See
-- [GHC #21547](https://gitlab.haskell.org/ghc/ghc/-/issues/21547).
reverseAD :: forall a b. (Typeable a, Typeable b)
          => Code Q (a -> b)
          -> Code Q (a -> (b, b -> a))
reverseAD = reverseAD' inpStructure outStructure
  where
    -- Both structure descriptions are derived from the 'Typeable' evidence.
    inpStructure = structureFromTypeable (Proxy :: Proxy a)
    outStructure = structureFromTypeable (Proxy :: Proxy b)

-- | Same as 'reverseAD', but with user-supplied 'Structure's instead of ones
-- derived from 'Typeable'.
reverseAD' :: forall a b.
              Q Structure  -- ^ Structure of @a@
           -> Q Structure  -- ^ Structure of @b@
           -> Code Q (a -> b)
           -> Code Q (a -> (b, b -> a))
reverseAD' inpStruc outStruc code = unsafeCodeCoerce generate
  where
    -- Effects run in this order: input structure, output structure, the
    -- untyped input expression, and finally the transformation itself.
    generate = do
      inS <- inpStruc
      outS <- outStruc
      unTypeCode code >>= transform inS outS

-- | Translate the quoted Template Haskell expression into the source AST (via
-- 'translate') and hand its lambda to 'transform''. Anything other than a
-- lambda at the top level is rejected with a compile-time error.
transform :: Structure -> Structure -> TH.Exp -> Q TH.Exp
transform inpStruc outStruc expr =
  translate expr >>= \case
    ELam pat body -> transform' inpStruc outStruc pat body
    other -> fail ("Top-level expression in reverseAD must be lambda, but is: " ++ show other)

-- | Generate the reverse-AD wrapper for a top-level lambda @\\pat -> expr@.
--
-- The generated function takes the original argument and interleaves it
-- according to the input 'Structure'. It then runs the differentiated body
-- ('ddr') inside 'runFwdM' and deinterleaves the result according to the
-- output 'Structure'. The result is a pair: the primal output, and a function
-- that feeds an adjoint through 'dualpass' and rebuilds the input-shaped
-- gradient.
transform' :: Structure -> Structure -> Pat -> Source.Exp -> Q TH.Exp
transform' inpStruc outStruc pat expr = do
  -- Deep-force both structure descriptions before generating any code.
  liftIO $ mapM_ (evaluate . rnf) [inpStruc, outStruc]
  boundInPat <- boundVars pat

  -- Fresh names for everything the generated code binds.
  argN       <- newName "arg"
  rebuildN   <- newName "rebuild"
  rebuildN'  <- newName "rebuild'"
  outN       <- newName "out"
  outJdN     <- newName "outjd"
  outNextJiN <- newName "outnextji"
  primalN    <- newName "primal"
  primalN'   <- newName "primal'"
  dualN      <- newName "dual"
  dualN'     <- newName "dual'"
  adjN       <- newName "adjoint"

  [| \ $(varP argN) ->
        let ($(varP outJdN), $(varP outNextJiN), ($(varP rebuildN), $(varP primalN), $(varP dualN))) =
              runFwdM $ do
                ($(pure pat), $(varP rebuildN')) <- $(interleave inpStruc (VarE argN))
                $(varP outN) <- $(ddr boundInPat expr)
                ($(varP primalN'), $(varP dualN')) <- mpure $(deinterleave outStruc (VarE outN))
                -- liftIO $ debug "evaluate start"
                -- _ <- return $! $(varE primalN')
                -- _ <- return $! $(varE dualN')
                -- liftIO $ debug "evaluate done"
                mpure ($(varE rebuildN'), $(varE primalN'), $(varE dualN'))
        in ( $(varE primalN)
           , \ $(varP adjN) ->
               $(varE rebuildN) $! dualpass $(varE outJdN) $(varE outNextJiN) $(varE dualN) $(varE adjN) ) |]


-- ----------------------------------------------------------------------
-- The compositional code transformation
-- ----------------------------------------------------------------------

-- | A \"dual\" number: a strict primal 'Double' plus the bookkeeping used by
-- the reverse pass.
data DN =
  DN
    {-# UNPACK #-} !Double  -- primal value (literals store their value here)
    {-# UNPACK #-} !NID     -- node id (literals get a fresh one from fwdmGenId)
    !Contrib                -- contribution (literals use C0)

-- | The local names bound by the program being differentiated that are in
-- scope at the current point of the transformation.
type Env = Set.Set Name

-- | The compositional differentiation of a source expression.
--
-- > Γ |- t : a                     -- expression
-- > ~> Dt[Γ] |- D[t] : FwdM Dt[a]  -- result
--
-- Variables in the environment are already differentiated and are returned
-- as-is. Any other variable must be one of the primitives recognised below, or
-- a global whose argument and result types are all discrete (per
-- 'isDiscrete').
ddr :: Env -> Source.Exp -> Q TH.Exp
ddr env = \case
    EVar name
      | name `Set.member` env -> [| mpure $(varE name) |]
      | name == 'fromIntegral -> [| mpure fromIntegralOp |]
      | name == 'negate -> do
          xn <- newName "x"
          [| mpure (\ $(varP xn) -> applyUnaryOp $(varE xn) negate (\_ -> (-1))) |]
      | name == 'sqrt -> unaryPrim 'sqrt (\p -> [| 1 / (2 * sqrt $p) |])
      | name == 'exp -> do
          -- exp' = exp; the second callback argument is presumably the primal
          -- output, which is returned directly as the derivative.
          xn <- newName "x"
          outn <- newName "p"
          [| mpure (\ $(varP xn) ->
                applyUnaryOp2 $(varE xn) exp (\_ $(varP outn) -> $(varE outn))) |]
      | name == 'log -> unaryPrim 'log (\p -> [| 1 / $p |])
      | name == 'sin -> unaryPrim 'sin (\p -> [| cos $p |])
      | name == 'cos -> unaryPrim 'cos (\p -> [| - sin $p |])
      | name == '($) -> do
          -- Rewrite to (\f -> \x -> f x) and differentiate that instead.
          fn <- newName "f"
          xn <- newName "x"
          ddr env (ELam (VarP fn) (ELam (VarP xn) (EVar fn `EApp` EVar xn)))
      | name == '(.) -> do
          -- Rewrite to (\f -> \g -> \x -> f (g x)) and differentiate that instead.
          fn <- newName "f"
          gn <- newName "g"
          xn <- newName "x"
          ddr env (ELam (VarP fn) (ELam (VarP gn) (ELam (VarP xn)
                     (EVar fn `EApp` (EVar gn `EApp` EVar xn)))))
      | name == '(+) -> ddrNumBinOp '(+) (False, False) (\_ _ -> [| (1, 1) |])
      | name == '(-) -> ddrNumBinOp '(-) (False, False) (\_ _ -> [| (1, -1) |])
      | name == '(*) -> ddrNumBinOp '(*) (True, True) (\x y -> [| ($y, $x) |])
      | name == '(/) -> ddrNumBinOp '(/) (True, True) (\x y -> [| (recip $y, (- $x) / ($y * $y)) |])
      | name `elem` ['(==), '(/=), '(<), '(>), '(<=), '(>=)] -> do
          xn <- newName "x"
          yn <- newName "y"
          [| mpure (\ $(varP xn) ->
               mpure (\ $(varP yn) ->
                 mpure (applyCmpOp $(varE xn) $(varE yn) $(varE name)))) |]
      | name == '(|*|) ->
          fail "The parallel combinator (|*|) should be applied directly; partially applying it is pointless."
      | name == 'error -> [| mpure error |]
      | name == 'fst -> [| mpure (\x -> mpure (fst x)) |]
      | name == 'snd -> [| mpure (\x -> mpure (snd x)) |]
      | otherwise -> freeVariable name

    ECon name -> do
      fieldtys <- checkDatacon name
      [| mpure $(liftKleisliN (length fieldtys) (ConE name)) |]

    ELit lit -> case lit of
      RationalL _ -> do
        iname <- newName "i"
        [| do $(varP iname) <- fwdmGenId
              mpure (DN $(litE lit) $(varE iname) C0) |]
      FloatPrimL _ -> fail "float prim?"
      DoublePrimL _ -> fail "double prim?"
      IntegerL _ -> [| mpure $(litE lit) |]
      StringL _ -> [| mpure $(litE lit) |]
      _ -> fail $ "literal? " ++ show lit

    -- A saturated parallel combinator: differentiate both sides in parallel.
    EApp (EApp (EVar name) e1) e2
      | name == '(|*|) -> [| fwdmPar $(ddr env e1) $(ddr env e2) |]

    EApp fun arg -> do
      (wrap, [fn, an]) <- ddrList env [fun, arg]
      pure (wrap (AppE (VarE fn) (VarE an)))

    ELam pat body -> do
      patbound <- boundVars pat
      body' <- ddr (env <> patbound) body
      [| mpure (\ $(pure pat) -> $(pure body')) |]

    ETup es -> do
      (wrap, vars) <- ddrList env es
      pure (wrap (AppE (VarE 'mpure) (TupE [Just (VarE v) | v <- vars])))

    ECond cond thenE elseE -> do
      cond' <- ddr env cond
      boolN <- newName "bool"
      then' <- ddr env thenE
      else' <- ddr env elseE
      pure (DoE Nothing
              [ BindS (VarP boolN) cond'
              , NoBindS (CondE (VarE boolN) then' else') ])

    ELet decs body -> do
      (wrap, vars) <- ddrDecs env decs
      wrap <$> ddr (env <> Set.fromList vars) body

    ECase scrut alts -> do
      (wrap, [evar]) <- ddrList env [scrut]
      alts' <- mapM caseAlt alts
      pure (wrap (CaseE (VarE evar) alts'))

    EList es -> do
      (wrap, vars) <- ddrList env es
      pure (wrap (AppE (VarE 'mpure) (ListE [VarE v | v <- vars])))

    ESig e ty ->
      [| $(ddr env e) :: FwdM $(ddrType ty) |]
  where
    -- A unary primitive whose derivative is written in terms of the primal
    -- input (bound to the callback's argument).
    unaryPrim :: Name -> (Q TH.Exp -> Q TH.Exp) -> Q TH.Exp
    unaryPrim prim deriv = do
      xn <- newName "x"
      pn <- newName "p"
      [| mpure (\ $(varP xn) ->
            applyUnaryOp $(varE xn) $(varE prim) (\ $(varP pn) -> $(deriv (varE pn)))) |]

    -- A global that is not a known primitive: only allowed when every argument
    -- type and the result type is discrete; it is then lifted into FwdM.
    freeVariable :: Name -> Q TH.Exp
    freeVariable name = do
      typ <- reifyType name
      let (params, retty) = unpackFunctionType typ
      if all isDiscrete (params ++ [retty])
        then [| mpure $(liftKleisliN (length params) (VarE name)) |]
        else fail $ "Most free variables not supported in reverseAD: " ++ show name ++
                    " (env = " ++ show env ++ ")"

    -- One case alternative: check the pattern, then differentiate the
    -- right-hand side with the pattern's variables in scope.
    caseAlt :: (Pat, Source.Exp) -> Q Match
    caseAlt (pat, rhs) = do
      patbound <- boundVars pat
      ddrPat pat
      rhs' <- ddr (env <> patbound) rhs
      pure (Match pat (NormalB rhs') [])

-- | Differentiate a binary numeric primitive.
--
-- Produces a curried, 'FwdM'-wrapped function. It applies @op@ to its two
-- arguments via 'applyBinaryOp', together with a gradient callback built by
-- @mkgrad@.
ddrNumBinOp :: Name  -- ^ Primal operator
            -> (Bool, Bool)
                  -- ^ Whether the gradient uses its (first, second) argument.
                  -- An argument flagged 'False' is not bound in the generated
                  -- gradient lambda (a wildcard is used instead), so @mkgrad@
                  -- must not refer to it.
            -> (Q TH.Exp -> Q TH.Exp -> Q TH.Exp)
                  -- ^ Gradient given inputs (assuming adjoint 1).
                  -- Should return a pair: the partial derivatives with respect
                  -- to the two inputs. The two arguments are expressions
                  -- referring to the two primal inputs.
            -> Q TH.Exp
ddrNumBinOp op (xused, yused) mkgrad = do
  xname <- newName "x"
  yname <- newName "y"
  pxname <- newName "px"
  pyname <- newName "py"
  [| mpure (\ $(varP xname) ->
       mpure (\ $(varP yname) ->
         applyBinaryOp $(varE xname) $(varE yname)
                       $(varE op)
                       (\ $(gradParam xused pxname) $(gradParam yused pyname) ->
                            $(mkgrad (gradArg xused pxname) (gradArg yused pyname))))) |]
  where
    -- Bind the primal input only if the gradient needs it; otherwise use a
    -- wildcard so the generated lambda has no unused binder.
    gradParam :: Bool -> Name -> Q Pat
    gradParam used nm = if used then varP nm else wildP

    -- Reference to a primal input. If the input was declared unused it is not
    -- bound, so fail with a clear message rather than splicing an
    -- out-of-scope variable into the user's code. This only runs if @mkgrad@
    -- actually splices the argument.
    gradArg :: Bool -> Name -> Q TH.Exp
    gradArg used nm
      | used = varE nm
      | otherwise =
          fail $ "ddrNumBinOp: gradient for " ++ show op
              ++ " refers to an argument that was declared unused"

-- | Differentiate every expression of a list in the same environment.
--
-- Returns a wrapper that binds one fresh variable per list element, in list
-- order, as statements of a @do@ block before the wrapped expression. It also
-- returns the names of those variables in the same order.
ddrList :: Env -> [Source.Exp] -> Q (TH.Exp -> TH.Exp, [Name])
ddrList env es = do
  vars <- traverse (\i -> newName ('x' : show i)) [1 .. length es]
  rhss <- traverse (ddr env) es
  let binds = zipWith (BindS . VarP) vars rhss
  pure (\rest -> DoE Nothing (binds ++ [NoBindS rest]), vars)

-- | Differentiate the declaration groups of a @let@ block.
--
-- Groups are processed in order, and each one sees the names bound by the
-- groups before it. Returns:
--
-- * a wrapper that binds all of the declared names (as statements of a @do@
--   block) before the wrapped body;
-- * the list of declared names, in declaration order.
ddrDecs :: Env -> [DecGroup] -> Q (TH.Exp -> TH.Exp, [Name])
ddrDecs env decs = do
  (stmts, declared) <- goGroups env decs
  pure (\body -> DoE Nothing (stmts ++ [NoBindS body]), declared)
  where
    -- Differentiate one group at a time, extending the scope as we go.
    goGroups :: Env -> [DecGroup] -> Q ([Stmt], [Name])
    goGroups _ [] = pure ([], [])
    goGroups scope (grp : grps) = do
      (stmt, bound) <- ddrDecGroup scope grp
      (later, laterBound) <- goGroups (scope <> Set.fromList bound) grps
      pure (stmt : later, bound ++ laterBound)

-- | Differentiate a single declaration group.
--
-- * @DecVar name msig rhs@ becomes the monadic bind @name <- rhs'@.
--   @rhs'@ is @rhs@ differentiated in the incoming environment; @name@ itself
--   is not added to that environment. When a signature is present, @rhs'@ is
--   annotated with @FwdM sig'@, where @sig'@ is the differentiated signature.
-- * @DecMutGroup lams@ becomes a single @let@ statement. It contains one
--   binding @name = \\pat -> rhs'@ per entry, preceded by a differentiated
--   signature when one was given. Every right-hand side is differentiated with
--   all names of the group, plus the variables bound by its own argument
--   pattern, in scope.
--
-- Returns the generated statement together with the names it binds.
ddrDecGroup :: Env -> DecGroup -> Q (Stmt, [Name])
ddrDecGroup env (DecVar name msig rhs) = do
  body <- ddr env rhs
  annotated <- maybe (return body)
                     (fmap (SigE body . AppT (ConT ''FwdM)) . ddrType)
                     msig
  return (BindS (VarP name) annotated, [name])
ddrDecGroup env (DecMutGroup lams) = do
  let groupNames = map (\(name, _, _, _) -> name) lams
      groupScope = env <> Set.fromList groupNames
      -- Differentiate one lambda of the group into its declarations.
      translate (name, msig, pat, rhs) = do
        ddrPat pat
        argVars <- boundVars pat
        body <- ddr (groupScope <> argVars) rhs
        let binding = ValD (VarP name) (NormalB (LamE [pat] body)) []
        sigs <- case msig of
          Nothing -> return []
          Just sig -> (\sig' -> [SigD name sig']) <$> ddrType sig
        return (sigs ++ [binding])
  decs <- concat <$> mapM translate lams
  return (LetS decs, groupNames)

-- | Validate a pattern for use in differentiated code.
--
-- This only checks that the data constructors appearing in the pattern are of
-- supported types (via 'checkDatacon'), recursing into sub-patterns. The
-- pattern itself is not rewritten. Pattern forms the differentiator does not
-- handle are rejected in 'Q' with an explanatory message.
ddrPat :: Pat -> Q ()
ddrPat = \case
  -- Patterns that need no checking.
  VarP{} -> return ()
  WildP -> return ()

  -- Wrappers and containers: check the sub-patterns.
  ParensP inner -> ddrPat inner
  AsP _ inner -> ddrPat inner
  TupP subs -> mapM_ ddrPat subs
  UnboxedTupP subs -> mapM_ ddrPat subs
  ListP subs -> mapM_ ddrPat subs
  InfixP lhs con rhs -> ddrPat (ConP con [] [lhs, rhs])

  -- Constructor patterns: the constructor must belong to a supported type.
  pat@(ConP con tyapps subs)
    | null tyapps -> do
        -- Ignore the field types; validity alone is good enough, assuming
        -- that the user's code was okay.
        _ <- checkDatacon con
        mapM_ ddrPat subs
    | otherwise -> reject "Type applications in patterns" pat

  -- Unsupported pattern forms.
  LitP{} -> fail "Literals in patterns unsupported"
  pat@UnboxedSumP{} -> reject "Unboxed sums" pat
  pat@UInfixP{} -> reject "UInfix patterns" pat
  pat@TildeP{} -> reject "Irrefutable patterns" pat
  pat@BangP{} -> reject "Bang patterns" pat
  pat@RecP{} -> reject "Records" pat
  pat@SigP{} -> reject "Type signatures in patterns, because then I need to rewrite types and I'm lazy" pat
  pat@ViewP{} -> reject "View patterns" pat
  where
    -- Report an unsupported feature, attaching the offending pattern.
    reject :: String -> Pat -> Q ()
    reject what pat = notSupported what (Just (show pat))

-- | Translate a source type into the type of its differentiated counterpart.
--
-- * @Double@ becomes @DN@ and @Int@ stays @Int@.
-- * A function type @a -> b@ becomes @a' -> FwdM b'@. Unrestricted
--   multiplicity arrows (@'Many@) are treated like plain arrows.
-- * Saturated tuples and type-constructor applications are translated
--   argument by argument.
--
-- Any other type fails in 'Q'. The error reports the first offending subterm
-- together with the complete input type.
ddrType :: Type -> Q Type
ddrType ty = either complain return (translate ty)
  where
    complain :: Type -> Q Type
    complain bad = fail $
      "Don't know how to differentiate (" ++ show bad
        ++ "), which is a part of the type: " ++ show ty

    -- Left carries the first subterm that could not be translated.
    translate :: Type -> Either Type Type
    translate = \case
      ConT name
        | name == ''Double -> Right (ConT ''DN)
        | name == ''Int -> Right (ConT ''Int)
      AppT (AppT ArrowT arg) res ->
        (\arg' res' -> AppT (AppT ArrowT arg') (AppT (ConT ''FwdM) res'))
          <$> translate arg
          <*> translate res
      AppT (AppT (AppT MulArrowT (PromotedT multi)) arg) res
        | multi == 'Many -> translate (AppT (AppT ArrowT arg) res)
      other -> case collectApps other of
        (TupleT n, args)
          | length args == n -> foldl AppT (TupleT n) <$> traverse translate args
        -- I hope this one is correct
        (ConT name, args) -> foldl AppT (ConT name) <$> traverse translate args
        _ -> Left other  -- don't know how to handle this type


-- ----------------------------------------------------------------------
-- The wrapper: interleave and deinterleave
-- ----------------------------------------------------------------------

-- | Build the expression that prepares the program input for the
-- differentiated program.
--
-- input :: a
-- result :: FwdM (Dt[a], VS.Vector Double -> a)
--
-- One helper function is generated per (data type, type arguments) key of the
-- structure's data type map, named with the "inter" tag. All helpers are bound
-- in a single @let@ around the application of the main interleaver to @input@.
interleave :: Quote m => Structure -> TH.Exp -> m TH.Exp
interleave (Structure monotype dtypemap) input = do
  helpernames <- Map.traverseWithKey
                   (\(n, ts) _ -> newName (genDataNameTag "inter" n ts))
                   dtypemap
  helperfuns <- Map.traverseWithKey
                  (\_ constrs -> interleaveData helpernames constrs)
                  dtypemap
  mainfun <- interleaveType helpernames monotype
  let bindings = [ ValD (VarP (helpernames Map.! key)) (NormalB fun) []
                 | (key, fun) <- Map.toAscList helperfuns ]
  return (LetE bindings (mainfun `AppE` input))

-- | Given the constructors of a data type T (each paired with its monomorphic
-- field types), returns an expression of type
-- 'T -> FwdM (Dt[T], VS.Vector Double -> T)'. The Map contains, for each
-- (type name T', type arguments As') combination that occurs (transitively) in
-- T, the name of a function with type
-- 'T' As' -> FwdM (Dt[T' As'], VS.Vector Double -> T' As')'.
--
-- A type without constructors yields an empty @case@ instead of crashing.
-- The splice site may then need the @EmptyCase@ extension.
interleaveData :: Quote m => Map (Name, [MonoType]) Name -> [(Name, [MonoType])] -> m TH.Exp
interleaveData helpernames constrs = do
  inputvar <- newName "input"
  -- The leading 0 keeps 'maximum' total when there are no constructors.
  let maxn = maximum (0 : map (length . snd) constrs)
  -- Variable supplies shared by all constructors. Each alternative uses the
  -- prefix whose length equals its number of fields.
  allinpvars     <- mapM (\i -> newName ("inp"  ++ show i)) [1..maxn]
  allpostvars    <- mapM (\i -> newName ("inp'" ++ show i)) [1..maxn]
  allrebuildvars <- mapM (\i -> newName ("reb"  ++ show i)) [1..maxn]

  -- One case alternative per constructor:
  --   C₁ inp₁ inp₂ inp₃ -> <body for C₁>
  --   C₂ inp₁ inp₂      -> <body for C₂>
  alts <- sequence
    [do -- For constructor C f₁…f₃ the body is:
        --   do (post₁, rebuild₁) <- $(interleaveType helpernames f₁) inp₁
        --      (post₂, rebuild₂) <- $(interleaveType helpernames f₂) inp₂
        --      (post₃, rebuild₃) <- $(interleaveType helpernames f₃) inp₃
        --      mpure (C post₁ post₂ post₃
        --            ,\arr -> C (rebuild₁ arr) (rebuild₂ arr) (rebuild₃ arr))
        --
        -- interleaveType helpernames (Monotype for T) :: Exp (T -> FwdM (Dt[T], Vector Double -> T))
        let nfields = length fieldtys
            inpvars = take nfields allinpvars
            postvars = take nfields allpostvars
            rebuildvars = take nfields allrebuildvars
        exps <- mapM (interleaveType helpernames) fieldtys
        arrname <- newName "arr"
        let binds = [ BindS (TupP [VarP postvar, VarP rebuildvar])
                            (expr `AppE` VarE inpvar)
                    | (inpvar, postvar, rebuildvar, expr)
                        <- zip4 inpvars postvars rebuildvars exps ]
            -- Without fields the rebuild function ignores its argument.
            arrpat = if null rebuildvars then WildP else VarP arrname
            primal = foldl AppE (ConE conname) (map VarE postvars)
            rebuild = LamE [arrpat] $
                        foldl AppE (ConE conname)
                              [VarE f `AppE` VarE arrname | f <- rebuildvars]
            final = NoBindS $ VarE 'mpure `AppE` pair primal rebuild
        return $ Match (ConP conname [] (map VarP inpvars))
                       (NormalB (DoE Nothing (binds ++ [final])))
                       []
    | (conname, fieldtys) <- constrs]

  -- \input -> case input of { alts }
  return $ LamE [VarP inputvar] $ CaseE (VarE inputvar) alts

-- | Given a type T, returns an expression of type
-- 'T -> FwdM (Dt[T], Vector Double -> T)'. The Map contains, for each
-- (type name T', type arguments As') combination that occurs in the type, the
-- name of a function with type
-- 'T' As' -> FwdM (Dt[T' As'], VS.Vector Double -> T' As')'.
--
-- * Discrete values are passed through unchanged; rebuilding ignores the
--   vector and returns the original value.
-- * Scalars get a fresh id from @fwdmGenIdInterleave@ and are wrapped in @DN@.
--   Rebuilding reads the vector at that id.
-- * Data types delegate to their helper from the Map. A missing helper is an
--   'error', raised lazily when the generated variable name is forced.
interleaveType :: Quote m => Map (Name, [MonoType]) Name -> MonoType -> m TH.Exp
interleaveType _ DiscreteST = do
  x <- newName "inpx"
  [| \ $(varP x) -> mpure ($(varE x), \_ -> $(varE x)) |]
interleaveType _ ScalarST = do
  x <- newName "inpx"
  ident <- newName "i"
  vec <- newName "arr"
  [| \ $(varP x) -> do
         $(varP ident) <- fwdmGenIdInterleave
         mpure ( DN $(varE x) (NID (JobID 0) $(varE ident)) C0
               , \ $(varP vec) -> $(varE vec) VS.! $(varE ident) ) |]
interleaveType helpernames (ConST tyname argtys) =
  return (VarE (maybe missing id (Map.lookup key helpernames)))
  where
    key = (tyname, argtys)
    missing = error ("Helper name not defined? " ++ show key)

-- | Build the expression that splits the differentiated program's output into
-- the primal result and an action that seeds the initial adjoint
-- contributions.
--
-- outexp :: Dt[T]                       -- interleaved program output
-- result :: (T                          -- primal result
--           ,T -> BuildState -> IO ())  -- given adjoint, add initial contributions
--
-- One helper per (data type, type arguments) key of the structure's data type
-- map is generated, named with the "deinter" tag. All helpers are bound in a
-- single @let@ around the application of the main deinterleaver to @outexp@.
deinterleave :: Quote m => Structure -> TH.Exp -> m TH.Exp
deinterleave (Structure monotype dtypemap) outexp = do
  let entries = Map.toAscList dtypemap
  names <- mapM (\((n, ts), _) -> newName (genDataNameTag "deinter" n ts)) entries
  -- Keys come from a Map, so they are already distinct and ascending.
  let helpernames = Map.fromDistinctAscList (zip (map fst entries) names)
  funs <- mapM (\(_, constrs) -> deinterleaveData helpernames constrs) entries
  mainfun <- deinterleaveType helpernames monotype
  let bindings = zipWith (\name fun -> ValD (VarP name) (NormalB fun) []) names funs
  return (LetE bindings (mainfun `AppE` outexp))

-- Dt[T]                            -- interleaved program output
--   -> (T                          -- primal result
--      ,T -> BuildState -> IO ())  -- given adjoint, add initial contributions
-- The Map contains for each (type name T', type arguments As') combination
-- that occurs (transitively) in T, the name of a function with type
-- 'Dt[T' As'] -> (T' As', T' As' -> BuildState -> IO ())'.
deinterleaveData :: Quote m => Map (Name, [MonoType]) Name -> [(Name, [MonoType])] -> m TH.Exp
deinterleaveData :: forall (m :: * -> *).
Quote m =>
Map (Name, [SimpleType 'Mono]) Name
-> [(Name, [SimpleType 'Mono])] -> m Exp
deinterleaveData Map (Name, [SimpleType 'Mono]) Name
helpernames [(Name, [SimpleType 'Mono])]
constrs = do
  Name
dualvar <- String -> m Name
forall (m :: * -> *). Quote m => String -> m Name
newName String
"out"
  let maxn :: Int
maxn = [Int] -> Int
forall a. Ord a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum (((Name, [SimpleType 'Mono]) -> Int)
-> [(Name, [SimpleType 'Mono])] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SimpleType 'Mono] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SimpleType 'Mono] -> Int)
-> ((Name, [SimpleType 'Mono]) -> [SimpleType 'Mono])
-> (Name, [SimpleType 'Mono])
-> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Name, [SimpleType 'Mono]) -> [SimpleType 'Mono]
forall a b. (a, b) -> b
snd) [(Name, [SimpleType 'Mono])]
constrs)
  [Name]
alldvars <- (Int -> m Name) -> [Int] -> m [Name]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\Int
i -> String -> m Name
forall (m :: * -> *). Quote m => String -> m Name
newName (String
"d" String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i)) [Int
1..Int
maxn]
  [Name]
allpvars <- (Int -> m Name) -> [Int] -> m [Name]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\Int
i -> String -> m Name
forall (m :: * -> *). Quote m => String -> m Name
newName (String
"p" String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i)) [Int
1..Int
maxn]
  [Name]
allfvars <- (Int -> m Name) -> [Int] -> m [Name]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\Int
i -> String -> m Name
forall (m :: * -> *). Quote m => String -> m Name
newName (String
"f" String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i)) [Int
1..Int
maxn]
  [Name]
allavars <- (Int -> m Name) -> [Int] -> m [Name]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\Int
i -> String -> m Name
forall (m :: * -> *). Quote m => String -> m Name
newName (String
"a" String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i)) [Int
1..Int
maxn]

  Name
bsvar <- String -> m Name
forall (m :: * -> *). Quote m => String -> m Name
newName String
"bs"

  let composeActions :: [Exp] -> Exp
composeActions [] = [Pat] -> Exp -> Exp
LamE [Pat
WildP] (Name -> Exp
VarE 'return Exp -> Exp -> Exp
`AppE` [Maybe Exp] -> Exp
TupE [])
      composeActions [Exp]
l =
        [Pat] -> Exp -> Exp
LamE [Name -> Pat
VarP Name
bsvar] (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
          (Exp -> Exp -> Exp) -> [Exp] -> Exp
forall a. (a -> a -> a) -> [a] -> a
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldr1 (\Exp
a Exp
b -> Maybe Exp -> Exp -> Maybe Exp -> Exp
InfixE (Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
a) (Name -> Exp
VarE '(>>)) (Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
b))
                 ((Exp -> Exp) -> [Exp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (Exp -> Exp -> Exp
`AppE` Name -> Exp
VarE Name
bsvar) [Exp]
l)

  [Exp]
bodies <- [m Exp] -> m [Exp]
forall (t :: * -> *) (m :: * -> *) a.
(Traversable t, Monad m) =>
t (m a) -> m (t a)
forall (m :: * -> *) a. Monad m => [m a] -> m [a]
sequence
    [do let dvars :: [Name]
dvars = Int -> [Name] -> [Name]
forall a. Int -> [a] -> [a]
take ([SimpleType 'Mono] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SimpleType 'Mono]
fieldtys) [Name]
alldvars
            pvars :: [Name]
pvars = Int -> [Name] -> [Name]
forall a. Int -> [a] -> [a]
take ([SimpleType 'Mono] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SimpleType 'Mono]
fieldtys) [Name]
allpvars
            fvars :: [Name]
fvars = Int -> [Name] -> [Name]
forall a. Int -> [a] -> [a]
take ([SimpleType 'Mono] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SimpleType 'Mono]
fieldtys) [Name]
allfvars
            avars :: [Name]
avars = Int -> [Name] -> [Name]
forall a. Int -> [a] -> [a]
take ([SimpleType 'Mono] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SimpleType 'Mono]
fieldtys) [Name]
allavars
        [Exp]
exps <- (SimpleType 'Mono -> m Exp) -> [SimpleType 'Mono] -> m [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Map (Name, [SimpleType 'Mono]) Name -> SimpleType 'Mono -> m Exp
forall (m :: * -> *).
Quote m =>
Map (Name, [SimpleType 'Mono]) Name -> SimpleType 'Mono -> m Exp
deinterleaveType Map (Name, [SimpleType 'Mono]) Name
helpernames) [SimpleType 'Mono]
fieldtys
        Exp -> m Exp
forall a. a -> m a
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> m Exp) -> Exp -> m Exp
forall a b. (a -> b) -> a -> b
$ [Dec] -> Exp -> Exp
LetE [Pat -> Body -> [Dec] -> Dec
ValD ([Pat] -> Pat
TupP [Name -> Pat
VarP Name
pvar, Name -> Pat
VarP Name
fvar]) (Exp -> Body
NormalB (Exp
expr Exp -> Exp -> Exp
`AppE` Name -> Exp
VarE Name
dvar)) []
                      | (Name
dvar, Name
pvar, Name
fvar, Exp
expr) <- [Name] -> [Name] -> [Name] -> [Exp] -> [(Name, Name, Name, Exp)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [Name]
dvars [Name]
pvars [Name]
fvars [Exp]
exps] (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
                   Exp -> Exp -> Exp
pair ((Exp -> Exp -> Exp) -> Exp -> [Exp] -> Exp
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Exp -> Exp -> Exp
AppE (Name -> Exp
ConE Name
conname) ((Name -> Exp) -> [Name] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map Name -> Exp
VarE [Name]
pvars))
                        -- irrefutable (partial) pattern: that's what you get with sum types in
                        -- a non-dependent context.
                        ([Pat] -> Exp -> Exp
LamE [Name -> [Type] -> [Pat] -> Pat
ConP Name
conname [] ((Name -> Pat) -> [Name] -> [Pat]
forall a b. (a -> b) -> [a] -> [b]
map Name -> Pat
VarP [Name]
avars)] (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
                           -- TODO: is this type signature still necessary now that we've moved
                           -- away from linear types?
                           Exp -> Type -> Exp
SigE ([Exp] -> Exp
composeActions [Name -> Exp
VarE Name
fvar Exp -> Exp -> Exp
`AppE` Name -> Exp
VarE Name
avar
                                                | (Name
fvar, Name
avar) <- [Name] -> [Name] -> [(Name, Name)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Name]
fvars [Name]
avars])
                                (Type
MulArrowT Type -> Type -> Type
`AppT` Name -> Type
ConT 'Many Type -> Type -> Type
`AppT` Name -> Type
ConT ''BuildState Type -> Type -> Type
`AppT` (Name -> Type
ConT ''IO Type -> Type -> Type
`AppT` Int -> Type
TupleT Int
0)))
    | (Name
conname, [SimpleType 'Mono]
fieldtys) <- [(Name, [SimpleType 'Mono])]
constrs]

  Exp -> m Exp
forall a. a -> m a
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> m Exp) -> Exp -> m Exp
forall a b. (a -> b) -> a -> b
$ [Pat] -> Exp -> Exp
LamE [Name -> Pat
VarP Name
dualvar] (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> [Match] -> Exp
CaseE (Name -> Exp
VarE Name
dualvar)
    [Pat -> Body -> [Dec] -> Match
Match (Name -> [Type] -> [Pat] -> Pat
ConP Name
conname [] ((Name -> Pat) -> [Name] -> [Pat]
forall a b. (a -> b) -> [a] -> [b]
map Name -> Pat
VarP [Name]
dvars))
           (Exp -> Body
NormalB Exp
body)
           []
    | ((Name
conname, [SimpleType 'Mono]
fieldtys), Exp
body) <- [(Name, [SimpleType 'Mono])]
-> [Exp] -> [((Name, [SimpleType 'Mono]), Exp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(Name, [SimpleType 'Mono])]
constrs [Exp]
bodies
    , let dvars :: [Name]
dvars = Int -> [Name] -> [Name]
forall a. Int -> [a] -> [a]
take ([SimpleType 'Mono] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SimpleType 'Mono]
fieldtys) [Name]
alldvars]

-- | Generate the deinterleaver for a monomorphic type @T@. The generated
-- expression has type
--
-- > Dt[T]                            -- interleaved program output
-- >   -> (T                          -- primal result
-- >      ,T -> BuildState -> IO ())  -- given adjoint, add initial contributions
--
-- The Map contains, for each (type name T', type arguments As') combination
-- that occurs (transitively) in T, the name of a function with type
-- @Dt[T' As'] -> (T' As', T' As' -> BuildState -> IO ())@.
deinterleaveType :: Quote m => Map (Name, [MonoType]) Name -> MonoType -> m TH.Exp
deinterleaveType helpernames = \case
  -- Discrete values are returned as-is; their contribution function does nothing.
  DiscreteST -> do
    dname <- newName "d"
    [| \ $(varP dname) -> ($(varE dname), \_ _ -> return () :: IO ()) |]

  -- A 'DN' scalar: return its primal, plus 'addContrib' partially applied to
  -- the node's id and contribution field.
  ScalarST -> do
    primalname <- newName "prim"
    idname <- newName "id"
    cbname <- newName "cb"
    let dnPattern = ConP 'DN [] (map VarP [primalname, idname, cbname])
        contribution = VarE 'addContrib `AppE` VarE idname `AppE` VarE cbname
    return (LamE [dnPattern] (pair (VarE primalname) contribution))

  -- Data types are delegated to the helper registered in 'helpernames'.
  ConST tyname argtys ->
    let key = (tyname, argtys)
        missing = error $ "Helper name not defined? " ++ show key
    in return (VarE (maybe missing id (Map.lookup key helpernames)))

-- | Build a tag for the type constructor @tyname@ applied to @argtys@, starting
-- with @prefix@. Only the alphanumeric characters of each shown name are kept.
-- Each type argument is appended after an underscore.
--
-- Not necessarily unique: distinct inputs can produce the same tag.
genDataNameTag :: String -> Name -> [MonoType] -> String
genDataNameTag prefix tyname argtys =
  concat (prefix : goN tyname : [ '_' : goT argty | argty <- argtys ])
  where
    -- Alphanumeric characters of the shown name, or "xx" if none remain.
    goN :: Name -> String
    goN n
      | null cleaned = "xx"
      | otherwise = cleaned
      where cleaned = filter isAlphaNum (show n)

    -- Nested type arguments are concatenated without separators.
    goT :: MonoType -> String
    goT = \case
      DiscreteST -> "i"
      ScalarST -> "s"
      ConST n ts -> goN n ++ concatMap goT ts


-- ----------------------------------------------------------------------
-- Polymorphic numeric operations
-- ----------------------------------------------------------------------
--
-- This is to get around the limitation of TH that we do not know the inferred
-- types of subexpressions in the AD transformation. Hence, for polymorphic
-- primitive operations, we defer the choice of implementation to the Haskell
-- typechecker using a type class.

-- | Primitive numeric operations on the dual representation of a primal type
-- @a@. Implementations are chosen by the Haskell typechecker from the type at
-- each use site.
class NumOperation a where
  -- | Dual representation of @a@. The family is injective, so the primal type
  -- is determined by its dual.
  type DualNum a = r | r -> a

  -- | Apply a binary operation to two dual arguments.
  applyBinaryOp
    :: DualNum a               -- ^ first argument
    -> DualNum a               -- ^ second argument
    -> (a -> a -> a)           -- ^ primal operation
    -> (a -> a -> (a, a))      -- ^ gradient given inputs (assuming adjoint 1)
    -> FwdM (DualNum a)        -- ^ result

  -- | Apply a unary operation whose derivative depends only on the input.
  applyUnaryOp
    :: DualNum a               -- ^ argument
    -> (a -> a)                -- ^ primal operation
    -> (a -> a)                -- ^ derivative given input (assuming adjoint 1)
    -> FwdM (DualNum a)        -- ^ result

  -- | Apply a unary operation whose derivative may also use the primal result.
  applyUnaryOp2
    :: DualNum a               -- ^ argument
    -> (a -> a)                -- ^ primal operation
    -> (a -> a -> a)           -- ^ derivative given input and primal result (assuming adjoint 1)
    -> FwdM (DualNum a)        -- ^ result

  -- | Compare two dual values using a comparison on primal values.
  applyCmpOp
    :: DualNum a               -- ^ first argument
    -> DualNum a               -- ^ second argument
    -> (a -> a -> Bool)        -- ^ primal comparison
    -> Bool                    -- ^ result

  -- | Convert an integral value into a dual value.
  fromIntegralOp
    :: Integral b
    => b                       -- ^ argument
    -> FwdM (DualNum a)        -- ^ result

-- | Doubles are differentiated: their dual is a 'DN' node.
--
-- Every operation except 'applyCmpOp' obtains a node id from 'fwdmGenId'. It
-- then records the incoming edges ('C0' / 'C1' / 'C2') from its arguments.
instance NumOperation Double where
  type DualNum Double = DN

  applyBinaryOp (DN x xi xcb) (DN y yi ycb) primal grad = do
    nodeId <- fwdmGenId
    let (dx, dy) = grad x y
        edges = C2 (CEdge xi xcb dx) (CEdge yi ycb dy)
    pure (DN (primal x y) nodeId edges)

  applyUnaryOp (DN x xi xcb) primal grad = do
    nodeId <- fwdmGenId
    let edge = CEdge xi xcb (grad x)
    pure (DN (primal x) nodeId (C1 edge))

  applyUnaryOp2 (DN x xi xcb) primal grad = do
    nodeId <- fwdmGenId
    -- The primal result is shared between the node value and the derivative.
    let result = primal x
        edge = CEdge xi xcb (grad x result)
    pure (DN result nodeId (C1 edge))

  -- Comparisons look only at the primal values.
  applyCmpOp (DN x _ _) (DN y _ _) cmp = cmp x y

  -- Constants get a node with 'C0' (presumably: no incoming edges).
  fromIntegralOp n = do
    nodeId <- fwdmGenId
    pure (DN (fromIntegral n) nodeId C0)

-- | Ints are not differentiated: the dual of an 'Int' is the 'Int' itself, and
-- every derivative/gradient argument is ignored.
instance NumOperation Int where
  type DualNum Int = Int

  applyBinaryOp a b op _ = pure $ op a b
  applyUnaryOp a op _ = pure $ op a
  applyUnaryOp2 a op _ = pure $ op a
  applyCmpOp a b cmp = cmp a b
  fromIntegralOp n = pure $ fromIntegral n


-- ----------------------------------------------------------------------
-- Further utility functions
-- ----------------------------------------------------------------------

-- | Returns the types of the fields of the data constructor if valid.
--
-- Fails in 'Q' in three cases:
--
-- * 'fromDataconType' cannot analyse the constructor's type;
-- * the result type is applied to something other than plain type variables;
-- * 'exploreRecursiveType' fails on the result type applied to @()@ arguments.
checkDatacon :: Name -> Q [Type]
checkDatacon name = do
  conty <- reifyType name
  (tycon, tyargs, fieldtys) <-
    maybe (fail $ "Could not deduce root type from type of data constructor " ++ pprint name)
          return
          (fromDataconType conty)
  tyvars <-
    maybe (fail "Normal constructor has GADT properties?")
          return
          (traverse asTyVar tyargs)
  -- Check that we can successfully derive the structure of the type applied to
  -- all-() type arguments. This _should_ be equivalent to a more general
  -- analysis that considers the type variables actual abstract entities.
  let appliedType = foldl AppT tycon (ConT ''() <$ tyvars)
  _ <- exploreRecursiveType appliedType
  return fieldtys
  where
    -- Only bare type variables are accepted as arguments of the result type.
    asTyVar (VarT n) = Just n
    asTyVar _ = Nothing

-- | Given the type of a data constructor, return:
--
-- * the type it is a constructor of (usually @'ConT' name@, but also
--   @'TupleT' n@);
-- * the instantiations of the type parameters of that type in the types of
--   the constructor's fields;
-- * the types of the fields of the constructor.
--
-- Returns 'Nothing' in two cases:
--
-- * 'extractTypeCon' rejects the final result type;
-- * a field arrow carries a multiplicity other than a promoted 'One'.
fromDataconType :: Type -> Maybe (Type, [Type], [Type])
fromDataconType = \case
  -- Skip an outer quantifier; its context is ignored.
  ForallT _ _ body -> fromDataconType body
  AppT (AppT ArrowT field) rest -> withField field rest
  AppT (AppT (AppT MulArrowT (PromotedT mult)) field) rest
    | mult == 'One -> withField field rest
    | otherwise -> Nothing
  result -> (\(tc, params) -> (tc, params, [])) <$> extractTypeCon result
  where
    -- Prepend one field type to the analysis of the remaining type.
    withField field rest =
      (\(tc, params, fields) -> (tc, params, field : fields)) <$> fromDataconType rest

-- | Split a type application into its head and arguments (in source order).
--
-- Three heads are accepted: a named type constructor, the list type (reported
-- as @'ConT' ''[]@), or a tuple type constructor. Any other head gives 'Nothing'.
extractTypeCon :: Type -> Maybe (Type, [Type])
extractTypeCon = go []
  where
    -- Arguments are consed on while descending, so the outermost argument
    -- ends up last, i.e. the accumulator is in source order.
    go args (AppT fun arg) = go (arg : args) fun
    go args (ConT n) = Just (ConT n, args)
    go args ListT = Just (ConT ''[], args)
    go args (TupleT n) = Just (TupleT n, args)
    go _ _ = Nothing

-- | Split a function type into its argument types and its final result type.
--
-- Only unpacks normal function arrows (including arrows explicitly marked with
-- multiplicity 'Many'), not linear ones.
unpackFunctionType :: Type -> ([Type], Type)
unpackFunctionType ty = case ty of
  AppT (AppT ArrowT argty) rest -> consArg argty rest
  AppT (AppT (AppT MulArrowT (PromotedT mult)) argty) rest
    | mult == 'Many -> consArg argty rest
  -- Anything else (including a failing guard above) is the result type.
  other -> ([], other)
  where
    consArg argty rest =
      let (args, result) = unpackFunctionType rest
      in (argty : args, result)

-- | Whether a type is discrete. That is one of:
--
-- * a type constructor listed in 'discreteTypeNames';
-- * a fully applied tuple of discrete types;
-- * a list of a discrete type.
isDiscrete :: Type -> Bool
isDiscrete = \case
  ConT n -> n `elem` discreteTypeNames
  t@AppT{} -> case collectApps t of
    (TupleT arity, args) | length args == arity -> all isDiscrete args
    (ListT, [arg]) -> isDiscrete arg
    _ -> False
  _ -> False

-- | Decompose a (possibly nested) type application into its head and its
-- arguments in source order: @collectApps (f a b) == (f, [a, b])@.
collectApps :: Type -> (Type, [Type])
collectApps = peel []
  where
    peel :: [Type] -> Type -> (Type, [Type])
    peel args (AppT fun arg) = peel (arg : args) fun
    peel args hd = (hd, args)

-- | Given an expression @e@, wraps it in @n@ kleisli-lifted lambdas like
--
-- > \x1 -> mpure (\x2 -> mpure (... \xn -> mpure (e x1 ... xn)))
--
-- A non-positive @n@ returns @e@ unchanged. Without that guard, a negative
-- @n@ would recurse without bound while generating fresh names.
liftKleisliN :: Int -> TH.Exp -> Q TH.Exp
liftKleisliN n e
  | n <= 0 = return e
  | otherwise = do
      name <- newName "x"
      [| \ $(varP name) -> mpure $(liftKleisliN (n - 1) (e `AppE` VarE name)) |]

-- | Build the tuple expression @(e1, e2)@.
pair :: TH.Exp -> TH.Exp -> TH.Exp
pair e1 e2 = TupE (map Just [e1, e2])

-- | The name bound by a type variable binder, with or without a kind annotation.
tvbName :: TyVarBndr () -> Name
tvbName bndr = case bndr of
  PlainTV n _ -> n
  KindedTV n _ _ -> n

-- | Union a collection of maps, resolving clashes with a key-aware function.
--
-- The fold is a 'foldr'. For a key present in several maps, @combine k x y@
-- receives @x@ from the earlier map. It receives as @y@ the value already
-- combined from the later maps.
mapUnionsWithKey :: (Foldable f, Ord k) => (k -> a -> a -> a) -> f (Map k a) -> Map k a
mapUnionsWithKey combine maps = foldr (Map.unionWithKey combine) Map.empty maps

-- | Master switch for 'evlog'. While this is 'False', 'evlog' returns without
-- touching 'evlogfile'.
kENABLE_EVLOG :: Bool
kENABLE_EVLOG = False

-- | Append a line to the debug event log, prefixed with the monotonic clock
-- time in seconds. Does nothing unless 'kENABLE_EVLOG' is set.
--
-- The write and flush happen while holding the 'evlogfile' MVar, so lines from
-- concurrent callers are not interleaved. The message is encoded as UTF-8.
-- Previously 'BS8.pack' kept only the low 8 bits of each character, which
-- corrupted anything outside Latin-1.
evlog :: String -> IO ()
evlog _ | not kENABLE_EVLOG = return ()
evlog s = do
  clk <- getTime Monotonic
  let stamp = show (fromIntegral (toNanoSecs clk) / 1e9 :: Double)
      -- The timestamp is plain ASCII, so Char8 packing is lossless for it.
      line = BS.append (BS8.pack stamp) (BS.pack (concatMap encodeCharUtf8 (' ' : s ++ "\n")))
  withMVar evlogfile $ \f -> do
    BS.hPut f line
    hFlush f
  where
    -- Encode one character as UTF-8 bytes.
    encodeCharUtf8 :: Char -> [Word8]
    encodeCharUtf8 c
      | cp < 0x80 = [fromIntegral cp]
      | cp < 0x800 = [lead 0xC0 6, cont 0]
      | cp < 0x10000 = [lead 0xE0 12, cont 6, cont 0]
      | otherwise = [lead 0xF0 18, cont 12, cont 6, cont 0]
      where
        cp = fromEnum c
        -- Leading byte: length marker plus the bits of cp from position k up.
        lead :: Int -> Int -> Word8
        lead marker k = fromIntegral (marker + cp `div` 2 ^ k)
        -- Continuation byte 10xxxxxx carrying bits k .. k+5 of cp.
        cont :: Int -> Word8
        cont k = fromIntegral (0x80 + (cp `div` 2 ^ k) `mod` 64)

-- NOINLINE keeps GHC from duplicating this top-level thunk. Without it,
-- 'unsafePerformIO' could open the file more than once.
{-# NOINLINE evlogfile #-}
-- | Process-wide handle of the debug event log ("evlogfile.txt", append
-- mode), guarded by an 'MVar' that 'evlog' holds while writing a line.
--
-- This is a lazy top-level value, so the file is only opened when
-- 'evlogfile' is first forced.
evlogfile :: MVar Handle
evlogfile = unsafePerformIO (openFile "evlogfile.txt" AppendMode >>= newMVar)