{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE ViewPatterns #-}
module Language.Haskell.ReverseAD.TH (
reverseAD,
reverseAD',
Structure,
structureFromTypeable,
structureFromType,
(|*|),
evlog,
) where
import Control.Concurrent
import Control.Monad (when)
import Control.Parallel (par)
import Data.Bifunctor (first, second)
import Data.Char (isAlphaNum)
import Data.List (zip4, intercalate)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.IORef
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Typeable
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector.Storable.Mutable as MVS
import Data.Word (Word8, Word16, Word32, Word64)
import GHC.Exts (Multiplicity(..))
import Language.Haskell.TH
import Language.Haskell.TH.Syntax as TH hiding (lift)
import System.Clock
import System.IO.Unsafe
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS8
import Control.DeepSeq
import Control.Exception (evaluate)
import Control.Monad.IO.Class
import System.IO
import Control.Concurrent.ThreadPool
import Data.Vector.Storable.Mutable.CAS
import Language.Haskell.ReverseAD.TH.Orphans ()
import Language.Haskell.ReverseAD.TH.Source as Source
import Language.Haskell.ReverseAD.TH.Translate
-- | Switch for the diagnostic output written to 'stderr' by the
-- @when kDEBUG@ guards in this module. Currently off.
kDEBUG :: Bool
kDEBUG = False
-- | Pair two values, sparking the evaluation of both components with 'par'
-- before the tuple is returned.
--
-- 'par' associates to the right, so this reads as
-- @x \`par\` (y \`par\` (x, y))@.
(|*|) :: a -> b -> (a, b)
x |*| y = x `par` y `par` (x, y)
-- | Identifier of a job. The wrapped 'Int' is also used as the index of the
-- job's slot in the 'BuildState' \/ 'MaybeBuildState' vectors.
newtype JobID = JobID Int
  deriving (Show)
-- | What came before a job.
data BeforeJob
  = Start
    -- ^ Nothing precedes this job.
  | Fork !JobDescr !JobDescr !JobDescr
    -- ^ The job continues after a fork-join. As constructed by @fwdmPar@, the
    -- fields are: the descriptor of the forking job, then the final
    -- descriptors of the first and second branch.
  deriving (Show)
-- | Description of a job: what preceded it, its identifier, and a counter.
--
-- The 'Int' is the counter that @fwdmGenId@ increments for every id it hands
-- out; @allocBS@ later uses its final value as the size of the job's arrays.
data JobDescr = JobDescr
    !BeforeJob
    {-# UNPACK #-} !JobID
    {-# UNPACK #-} !Int
  deriving (Show)
-- | Continuation-passing computation in 'IO' that threads a 'JobDescr'.
--
-- The arguments of the wrapped function are, in order:
--
-- * a shared reference holding the next unused 'JobID',
-- * the descriptor of the job currently being run,
-- * the continuation, which receives the updated descriptor and the result.
newtype FwdM a = FwdM
    (IORef JobID
     -> JobDescr
     -> (JobDescr -> a -> IO ()) -> IO ())
-- | Maps over the result handed to the continuation. The bang patterns force
-- the job descriptor and the intermediate result to WHNF before continuing.
instance Functor FwdM where
  fmap f (FwdM g) = FwdM $ \jr !jd k ->
    g jr jd $ \ !jd1 !x ->
      k jd1 (f x)
-- | Sequential application: the function computation runs first, and the
-- job descriptor it produces is passed on to the argument computation.
instance Applicative FwdM where
  -- The value is forced to WHNF when 'pure' is applied.
  pure !x = FwdM (\_ !jd k -> k jd x)

  FwdM f <*> FwdM g = FwdM $ \jr !jd k ->
    f jr jd $ \ !jd1 !fun ->
      g jr jd1 $ \ !jd2 !x ->
        k jd2 (fun x)
-- | Sequential bind: the job descriptor produced by the first computation is
-- the starting descriptor of the computation selected by its result.
instance Monad FwdM where
  FwdM f >>= g = FwdM $ \jr !jd k ->
    f jr jd $ \ !jd1 !x ->
      let FwdM h = g x
      in h jr jd1 k
-- | Runs the 'IO' action in place and passes its result to the continuation;
-- the job descriptor is passed through unchanged.
instance MonadIO FwdM where
  liftIO m = FwdM $ \_ !jd k -> m >>= k jd
-- | 'pure' with its type fixed to 'FwdM'.
mpure :: a -> FwdM a
mpure = pure
-- | Run a 'FwdM' computation to completion.
--
-- The computation starts as job 0 (@JobDescr Start (JobID 0) 0@), with 1 as
-- the next unused job id. The calling thread blocks on an 'MVar' until the
-- final continuation has been invoked.
--
-- Returns the final job descriptor, the next unused 'JobID' (read after
-- completion), and the result.
--
-- This hides 'IO' behind 'unsafePerformIO'; the function is marked NOINLINE.
-- NOTE(review): presumably NOINLINE is there so the IO block is not
-- duplicated by inlining -- confirm.
{-# NOINLINE runFwdM #-}
runFwdM :: FwdM a -> (JobDescr, JobID, a)
runFwdM (FwdM f) = unsafePerformIO $ do
  evlog "fwdm start"
  jiref <- newIORef (JobID 1)
  resvar <- newEmptyMVar
  -- The final continuation only deposits its two arguments in the MVar.
  f jiref (JobDescr Start (JobID 0) 0) (curry (putMVar resvar))
  (jd, y) <- takeMVar resvar
  nextji <- readIORef jiref
  evlog "fwdm end"
  pure (jd, nextji, y)
-- | Run two 'FwdM' computations as separate jobs and pair their results.
--
-- Three consecutive job ids are reserved atomically from the shared counter:
-- one for each branch and one for the job that continues after the join.
-- Both branches are submitted to 'globalThreadPool'. A freshly forked thread
-- waits for both results and then submits the continuation to the pool,
-- giving it a descriptor whose predecessor is
-- @Fork \<forking job\> \<first branch\> \<second branch\>@.
--
-- The action itself returns as soon as everything has been submitted; it
-- does not wait for the branches.
fwdmPar :: FwdM a -> FwdM b -> FwdM (a, b)
fwdmPar (FwdM f1) (FwdM f2) = FwdM $ \jr !jd0 k -> do
  (ji1, ji2, ji3) <-
    atomicModifyIORef' jr $ \(JobID j) ->
      (JobID (j + 3), (JobID j, JobID (j + 1), JobID (j + 2)))
  debug "! Starting fork"
  cell1 <- newEmptyMVar
  cell2 <- newEmptyMVar
  -- Each branch starts a fresh job and reports its final descriptor and
  -- result through its own MVar.
  submitJob globalThreadPool $
    f1 jr (JobDescr Start ji1 0) $ \jd1 x ->
      debug "! join left" >> putMVar cell1 (jd1, x)
  submitJob globalThreadPool $
    f2 jr (JobDescr Start ji2 0) $ \jd2 y ->
      debug "! join right" >> putMVar cell2 (jd2, y)
  -- The join blocks on the MVars in its own forked thread rather than
  -- inside a pool job.
  _ <- forkIO $ do
    (jd1, x) <- takeMVar cell1
    (jd2, y) <- takeMVar cell2
    debug "! Joined"
    submitJob globalThreadPool (k (JobDescr (Fork jd0 jd1 jd2) ji3 0) (x, y))
  pure ()
-- | Node identifier: the job that generated the id, plus the value of that
-- job's counter at the time (see @fwdmGenId@).
data NID = NID {-# UNPACK #-} !JobID
               {-# UNPACK #-} !Int
  deriving (Show)
-- | Generate a fresh node id in the current job: returns the job id together
-- with the current counter value, and continues with the counter incremented.
fwdmGenId :: FwdM NID
fwdmGenId = FwdM $ \_ (JobDescr prev ji i) k ->
  k (JobDescr prev ji (i + 1)) (NID ji i)
-- | Like 'fwdmGenId', but only valid in job 0, and returns the bare counter
-- value instead of an 'NID'.
--
-- Calls 'error' when used in any job other than job 0.
fwdmGenIdInterleave :: FwdM Int
fwdmGenIdInterleave = FwdM $ \_ (JobDescr prev ji@(JobID jiInt) i) k ->
  if jiInt == 0
    then k (JobDescr prev ji (i + 1)) i
    else error "fwdmGenIdInterleave: not on main thread"
-- | Per-job storage while it is still being allocated, indexed by the 'Int'
-- inside a 'JobID'. Slots start out as 'Nothing' and are filled in by
-- @allocBS@.
type MaybeBuildState = MV.IOVector (Maybe (MV.IOVector Contrib, MVS.IOVector Double))
-- | Per-job storage, indexed by the 'Int' inside a 'JobID': for every job, a
-- vector of 'Contrib's and a vector of 'Double's. Both inner vectors are
-- indexed by the 'Int' of an 'NID' (see @addContrib@).
type BuildState = MV.IOVector (MV.IOVector Contrib, MVS.IOVector Double)
-- | One edge of a 'Contrib': the target node, the 'Contrib' to store for that
-- target, and a 'Double' factor. When an edge is applied in @resolve@, the
-- incoming value is multiplied by the factor before being added to the
-- target.
data CEdge = CEdge {-# UNPACK #-} !NID
                   !Contrib
                   {-# UNPACK #-} !Double
-- | A collection of 'CEdge's, with dedicated constructors for zero, one and
-- two edges and a list for the general case.
data Contrib
  = C0
  | C1 {-# UNPACK #-} !CEdge
  | C2 {-# UNPACK #-} !CEdge {-# UNPACK #-} !CEdge
  | C3 ![CEdge]
-- | Render a 'Contrib' for debug output as @Contrib [e1,e2,...]@, with each
-- edge rendered by 'debugCEdge' and edges separated by a single comma.
debugContrib :: Contrib -> String
debugContrib contrib =
  "Contrib [" ++ intercalate "," (map debugCEdge (edges contrib)) ++ "]"
  where
    -- All constructors are flattened to a plain list of edges.
    edges :: Contrib -> [CEdge]
    edges C0 = []
    edges (C1 e) = [e]
    edges (C2 e1 e2) = [e1, e2]
    edges (C3 l) = l
-- | Render a 'CEdge' for debug output. The nested 'Contrib' is not shown; it
-- is replaced by the placeholder @<>@.
debugCEdge :: CEdge -> String
debugCEdge (CEdge nid _ d) = "(" ++ show nid ++ ", <>, " ++ show d ++ ")"
-- | Allocate the per-job arrays for every job reachable from the given
-- descriptor.
--
-- For each job, the slot at the job's index must still be 'Nothing'; it is
-- filled with a 'Contrib' vector (all 'C0') and a 'Double' vector (all 0.0),
-- both sized by the counter stored in the job's descriptor. At a 'Fork', the
-- two branches are visited before the forking job.
--
-- Calls 'error' if a slot has already been allocated.
allocBS :: JobDescr -> MaybeBuildState -> IO ()
allocBS topjobdescr threads = go topjobdescr
  where
    go :: JobDescr -> IO ()
    go (JobDescr prev (JobID ji) n) = do
      MV.read threads ji >>= \case
        Just _ -> error $ "allocBS: already allocated? (" ++ show ji ++ ")"
        Nothing -> do
          cbarr <- MV.replicate n C0
          adjarr <- MVS.replicate n 0.0
          MV.write threads ji (Just (cbarr, adjarr))
      case prev of
        Start -> pure ()
        Fork parent jd1 jd2 -> do
          go jd1
          go jd2
          go parent
-- | Walk the job graph backwards from the terminal job, applying every
-- stored 'Contrib', and block until the whole walk has finished.
--
-- Within a job, entries are processed from the highest index down to 0. At
-- a 'Fork', both branches are resolved as jobs on 'globalThreadPool'; once
-- both have finished, the forking job is resolved, also on the pool.
resolve :: JobDescr -> BuildState -> IO ()
resolve terminalJob threads = do
  cell <- newEmptyMVar
  resolveTask terminalJob (putMVar cell ())
  takeMVar cell
  where
    -- Resolve one job, then its predecessors; run @k@ once a 'Start' job has
    -- been resolved.
    resolveTask :: JobDescr -> IO () -> IO ()
    resolveTask (JobDescr prev ji@(JobID jix) i) k = do
      when kDEBUG $
        hPutStrLn stderr $ "Enter job " ++ show ji ++ " from i=" ++ show (i - 1)
      (cbarr, adjarr) <- MV.read threads jix
      resolveJob ji (i - 1) cbarr adjarr
      case prev of
        Start -> k
        Fork jd0 jd1 jd2 -> do
          let jidOf (JobDescr _ n _) = n
          when kDEBUG $
            hPutStrLn stderr $
              "Forking jobs (" ++ show (jidOf jd1) ++ ") and (" ++ show (jidOf jd2)
                ++ "); parent (" ++ show (jidOf jd0) ++ ")"
          cell1 <- newEmptyMVar
          cell2 <- newEmptyMVar
          submitJob globalThreadPool (resolveTask jd1 (putMVar cell1 ()))
          submitJob globalThreadPool (resolveTask jd2 (putMVar cell2 ()))
          -- Wait for both branches in a separate thread, then continue with
          -- the job that forked them.
          _ <- forkIO $ do
            takeMVar cell1
            takeMVar cell2
            submitJob globalThreadPool (resolveTask jd0 k)
          pure ()

    -- Process entries i, i-1, ..., 0 of a single job; index -1 ends the loop.
    resolveJob :: JobID -> Int -> MV.IOVector Contrib -> MVS.IOVector Double -> IO ()
    resolveJob _ (-1) _ _ = pure ()
    resolveJob jid i cbarr adjarr = do
      cb <- MV.read cbarr i
      adj <- MVS.read adjarr i
      when kDEBUG $
        hPutStrLn stderr $
          "Apply (" ++ show (NID jid i) ++ ") [adj=" ++ show adj ++ "]: " ++ debugContrib cb
      apply cb adj
      resolveJob jid (i - 1) cbarr adjarr

    -- Apply every edge of a contribution with the given value.
    apply :: Contrib -> Double -> IO ()
    apply C0 _ = pure ()
    apply (C1 e) a = applyEdge a e
    apply (C2 e1 e2) a = applyEdge a e1 >> applyEdge a e2
    apply (C3 l) a = mapM_ (applyEdge a) l

    -- Scale the value by the edge's factor and add it to the edge's target.
    applyEdge :: Double -> CEdge -> IO ()
    applyEdge a (CEdge nid cb d) = addContrib nid cb (a * d) threads
-- | Store a 'Contrib' for a node and add a value to that node's 'Double'.
--
-- The 'NID' selects the job's arrays in the 'BuildState' and the index within
-- them. The addition is a compare-and-swap retry loop: starting from a plain
-- read, it tries to replace the expected value by @expected + adj@ and
-- retries with the value reported by 'casIOVectorDouble' until the swap
-- succeeds.
--
-- NOTE(review): the 'Contrib' is stored with a plain, unsynchronised
-- 'MV.write' that overwrites the slot on every call. Presumably every caller
-- for a given node passes the same 'Contrib' -- confirm.
addContrib :: NID -> Contrib -> Double -> BuildState -> IO ()
addContrib (NID (JobID ji) i) cb adj threads = do
  (cbarr, adjarr) <- MV.read threads ji
  MV.write cbarr i cb
  let loop expected = do
        (success, old) <- casIOVectorDouble adjarr i expected (expected + adj)
        if success then pure () else loop old
  orig <- MVS.read adjarr i
  loop orig
-- | Run the reverse pass and return the resulting values of job 0.
--
-- Steps, each followed by an 'evlog' marker:
--
-- 1. allocate the arrays for every job reachable from @finaljob@ ('allocBS')
--    and check that all @numjobs@ slots were filled;
-- 2. run the supplied @backprop@ action on @adj@ and the 'BuildState';
-- 3. 'resolve' the job graph starting from @finaljob@;
-- 4. return a frozen copy of the 'Double' array of job 0.
--
-- Calls 'error' if some slot below @numjobs@ was not allocated in step 1.
--
-- This hides 'IO' behind 'unsafePerformIO'; the function is marked NOINLINE.
-- NOTE(review): presumably NOINLINE is there so the IO block is not
-- duplicated by inlining -- confirm.
{-# NOINLINE dualpass #-}
dualpass :: JobDescr -> JobID -> (d -> BuildState -> IO ()) -> d -> VS.Vector Double
dualpass finaljob (JobID numjobs) backprop adj = unsafePerformIO $ do
  evlog "dual start"
  threads' <- MV.replicate numjobs Nothing
  allocBS finaljob threads'
  -- Strip the 'Maybe' layer now that every slot should be allocated.
  threads <- MV.generateM numjobs $ \i ->
    MV.read threads' i >>= \case
      Just p -> pure p
      Nothing -> error $ "Thread array not initialised: " ++ show i
  evlog "dual allocated"
  backprop adj threads
  evlog "dual bped"
  resolve finaljob threads
  evlog "dual resolved"
  adjarr0 <- snd <$> MV.read threads 0
  res <- VS.freeze adjarr0
  evlog "dual frozen"
  pure res
-- | A monomorphic type description together with a 'DataTypes' value, as
-- produced by 'structureFromType'.
data Structure = Structure MonoType DataTypes
  deriving (Show)
-- | Forces the type description with its own 'rnf', and the 'DataTypes'
-- component by fully evaluating its 'show' rendering.
-- NOTE(review): presumably 'DataTypes' has no 'NFData' instance -- confirm.
instance NFData Structure where
  rnf (Structure m ds) = rnf m `seq` rnf (show ds)
-- | Build a 'Structure' from a Template Haskell type, by running
-- 'exploreRecursiveType' on it and wrapping the resulting pair.
structureFromType :: Q Type -> Q Structure
structureFromType ty = do
  t <- ty
  uncurry Structure <$> exploreRecursiveType t
-- | Build a 'Structure' for a 'Typeable' type, given a proxy for it.
--
-- The 'TypeRep' is converted to a Template Haskell 'Type' and handed to
-- 'structureFromType'.
structureFromTypeable :: Typeable a => Proxy a -> Q Structure
structureFromTypeable = structureFromType . typeRepToType . typeRep
  where
    -- Rebuild the type as its constructor applied to the converted
    -- arguments. The constructor name is a global type/class name assembled
    -- from the TyCon's package, module and name.
    typeRepToType :: TypeRep -> Q Type
    typeRepToType tr = do
      let (con, args) = splitTyConApp tr
          name =
            Name
              (OccName (tyConName con))
              (NameG TcClsName (PkgName (tyConPackage con)) (ModName (tyConModule con)))
      resultArgs <- mapM typeRepToType args
      pure (foldl AppT (ConT name) resultArgs)
-- | Tag used (promoted to the type level) as the index of 'SimpleType':
-- 'Poly' types may contain type variables ('VarST'), 'Mono' types may not.
data Morphic
  = Poly
  | Mono
  deriving (Show)
-- | Simplified view of a Haskell type, indexed by whether it may still
-- contain type variables.
data SimpleType morphic where
  -- | One of the types listed in 'discreteTypeNames' (see 'summariseType').
  DiscreteST :: SimpleType mf
  -- | The type 'Double' (see 'summariseType').
  ScalarST :: SimpleType mf
  -- | A type variable; only constructible at index @'Poly@.
  VarST :: Name -> SimpleType 'Poly
  -- | A named type constructor together with the arguments it is applied to.
  ConST :: Name -> [SimpleType mf] -> SimpleType mf

deriving instance Show (SimpleType mf)
deriving instance Eq (SimpleType mf)
deriving instance Ord (SimpleType mf)
deriving instance Lift (SimpleType mf)

-- | A 'SimpleType' that may contain type variables.
type PolyType = SimpleType 'Poly
-- | A 'SimpleType' without type variables ('VarST' is not constructible at
-- this index).
type MonoType = SimpleType 'Mono
-- | Deep evaluation of a 'SimpleType'.
--
-- 'Name's are only evaluated to weak head normal form (with 'seq'); the
-- argument list of a 'ConST' is forced recursively.
instance NFData (SimpleType mf) where
  rnf = \case
    DiscreteST -> ()
    ScalarST -> ()
    VarST var -> var `seq` ()
    ConST con args -> con `seq` rnf args
-- | Names of the types that 'summariseType' classifies as 'DiscreteST':
-- the fixed-width and machine-sized signed and unsigned integer types,
-- 'Char' and 'Bool'.
discreteTypeNames :: [Name]
discreteTypeNames = signedInts ++ unsignedWords ++ [''Char, ''Bool]
  where
    signedInts = [''Int, ''Int8, ''Int16, ''Int32, ''Int64]
    unsignedWords = [''Word, ''Word8, ''Word16, ''Word32, ''Word64]
-- | Convert a Template Haskell 'Type' to its simplified 'PolyType' form.
--
-- * A constructor in 'discreteTypeNames' becomes 'DiscreteST'.
-- * 'Double' becomes 'ScalarST'; 'Float' is rejected explicitly.
-- * Any other constructor, tuple constructor or the list constructor becomes
--   a 'ConST' without arguments.
-- * A type variable becomes 'VarST'.
-- * Parentheses are looked through.
-- * A type application is flattened with 'collectApps'; the head and the
--   arguments are summarised (head first) and recombined by @smartAppsST@.
--
-- Every other form of type fails in the monad with an
-- @"Unsupported type: "@ message.
summariseType :: MonadFail m => Type -> m PolyType
summariseType = \case
  ConT n
    | n `elem` discreteTypeNames -> pure DiscreteST
    | n == ''Double -> pure ScalarST
    | n == ''Float -> fail "Only Double is an active type for now (Float isn't)"
    | otherwise -> pure (ConST n [])
  VarT n -> pure (VarST n)
  ParensT t -> summariseType t
  TupleT k -> pure (ConST (tupleTypeName k) [])
  ListT -> pure (ConST ''[] [])
  t@AppT{} -> do
    let (hd, args) = collectApps t
    hd' <- summariseType hd
    args' <- traverse summariseType args
    smartAppsST hd' args'
  t -> fail ("Unsupported type: " ++ pprint t)
  where
    -- Apply an already-summarised head to already-summarised arguments.
    -- Only a 'ConST' head can take arguments; the new arguments are appended
    -- to the ones it already carries.
    smartAppsST :: MonadFail m => PolyType -> [PolyType] -> m PolyType
    smartAppsST DiscreteST _ =
      fail "Discrete type does not take type parameters"
    smartAppsST ScalarST _ =
      fail "'Double' does not take type parameters"
    -- A type variable in head position (as in @f a@) is a higher-kinded
    -- variable, not a higher-rank one, so the message says so.
    smartAppsST (VarST n) _ =
      fail ("Higher-kinded type variable not supported in reverse AD: " ++ show n)
    smartAppsST (ConST n as) bs =
      pure (ConST n (as ++ bs))
-- | Eliminate the type variables of a 'PolyType'.
--
-- Each 'VarST' is replaced by the result of applying the given function to
-- the variable's name; all other constructors are rebuilt unchanged, with
-- the arguments of a 'ConST' traversed left to right.
toMonoType :: Applicative f => (Name -> f MonoType) -> PolyType -> f MonoType
toMonoType resolveVar polyTy = case polyTy of
  DiscreteST -> pure DiscreteST
  ScalarST -> pure ScalarST
  VarST var -> resolveVar var
  ConST con args -> ConST con <$> traverse (toMonoType resolveVar) args
-- | Table of explored data types, as built by 'exploreType': the key is a
-- type name together with the monomorphic arguments it is applied to; the
-- value lists that type's constructors, each with its monomorphised field
-- types.
type DataTypes = Map (Name, [MonoType]) [(Name, [MonoType])]
-- | Collect the data types reachable from a monomorphic type.
--
-- Arguments:
--
-- * the stack of data types currently being explored (type name mapped to
--   the arguments it was applied to);
-- * the set of type synonyms currently being expanded;
-- * the type to explore.
--
-- 'DiscreteST' and 'ScalarST' contribute nothing.  For a 'ConST' that is
-- already on the stack with the same arguments, nothing is added (this cuts
-- off ordinary recursion).  Otherwise the type name is reified: @newtype@
-- and @data@ declarations without a context are recorded under the key
-- (name, arguments) after exploring the field types of every constructor;
-- a type synonym is expanded and exploration continues on its right-hand
-- side.  Anything else fails in 'Q'.
--
-- Exploration fails when either the stack or the synonym set reaches 20
-- entries.
--
-- NOTE(review): the stack is keyed by type name only and re-inserting a name
-- overwrites its entry, so a type that recurses directly at ever-changing
-- arguments does not appear to grow the stack; the depth-20 guard presumably
-- does not trigger in that case -- confirm.
exploreType :: Map Name [MonoType] -> Set Name -> MonoType -> Q DataTypes
exploreType stack tysynstack monoTy
  | Map.size stack >= 20 =
      fail "Very deep data type hierarchy (depth 20); polymorphic recursion?"
  | Set.size tysynstack >= 20 =
      fail "Very deep type synonym hierarchy (depth 20); self-recursive type synonym?"
  | otherwise = case monoTy of
      DiscreteST -> return mempty
      ScalarST -> return mempty
      ConST tyname argtys
        | Map.lookup tyname stack == Just argtys -> return mempty
        | otherwise -> exploreNamed tyname argtys
  where
    -- Reify a named type and dispatch on the kind of declaration found.
    exploreNamed :: Name -> [MonoType] -> Q DataTypes
    exploreNamed tyname argtys = do
      info <- reify tyname
      typedecl <- case info of
        TyConI decl -> return decl
        _ -> fail $ "Name " ++ show tyname ++ " is not a lifted type name: "
                    ++ show info
      case typedecl of
        NewtypeD [] _ tyvars _ constr _ -> do
          (condescr, typeset) <-
            analyseConstructor tyname argtys (map tvbName tyvars) constr
          return (Map.insertWithKey mergeIfEqual (tyname, argtys) [condescr] typeset)

        DataD [] _ tyvars _ constrs _ -> do
          analysed <-
            mapM (analyseConstructor tyname argtys (map tvbName tyvars)) constrs
          let (condescrs, typesets) = unzip analysed
              typeset = mapUnionsWithKey mergeIfEqual typesets
          return (Map.insertWithKey mergeIfEqual (tyname, argtys) condescrs typeset)

        TySynD _ tyvars rhs -> do
          srhs <- summariseType rhs
          mrhs <- monomorphiseField tyname (map tvbName tyvars) argtys srhs
          exploreType stack (Set.insert tyname tysynstack) mrhs

        _ -> fail $ "Type not supported: " ++ show tyname
                    ++ " (not simple newtype or data)"

    -- Describe one constructor of the type @tyname@ applied to @argtys@:
    -- returns the constructor name with its monomorphised field types, plus
    -- the data types reachable from those fields.  Fields are explored with
    -- the current type pushed onto the stack and a fresh synonym set.
    analyseConstructor
      :: Name -> [MonoType] -> [Name] -> Con
      -> Q ((Name, [MonoType]), DataTypes)
    analyseConstructor tyname argtys tyvars constr
      | length tyvars /= length argtys =
          fail $ "Type not fully applied: " ++ show tyname
      | otherwise = do
          (conname, fieldtys) <- constructorFields constr
          polyFields <- mapM summariseType fieldtys
          monoFields <- mapM (monomorphiseField tyname tyvars argtys) polyFields
          let stack' = Map.insert tyname argtys stack
          typesets <- mapM (exploreType stack' mempty) monoFields
          return ((conname, monoFields), mapUnionsWithKey mergeIfEqual typesets)

    -- Name and field types of a plain, record or infix constructor; other
    -- constructor forms are rejected.
    constructorFields :: Con -> Q (Name, [Type])
    constructorFields = \case
      NormalC conname fields -> return (conname, map snd fields)
      RecC conname fields -> return (conname, [ty | (_, _, ty) <- fields])
      InfixC (_, lhs) conname (_, rhs) -> return (conname, [lhs, rhs])
      other -> fail $ "Unsupported constructor format on data: " ++ show other

    -- Substitute the declaration's type parameters by the actual arguments
    -- in a field type; a variable that is not a parameter is an error.
    monomorphiseField :: Name -> [Name] -> [MonoType] -> PolyType -> Q MonoType
    monomorphiseField tyname typarams argtys =
      toMonoType $ \n ->
        case lookup n (zip typarams argtys) of
          Just mt -> return mt
          Nothing ->
            fail $ "Type variable out of scope in definition of data type "
                   ++ show tyname ++ ": " ++ show n

    -- Combining function for map unions/insertions: both explorations of the
    -- same key must agree, otherwise this is treated as an internal error.
    mergeIfEqual :: (Show k, Show v, Eq v) => k -> v -> v -> v
    mergeIfEqual key v1 v2
      | v1 == v2 = v1
      | otherwise =
          error $ "Conflicting explorations of the same data type!\nKey: "
                  ++ show key
                  ++ "\nVal 1: " ++ show v1
                  ++ "\nVal 2: " ++ show v2
-- | Analyse a top-level input or output type of reverse AD.
--
-- The type is summarised with 'summariseType' and must not contain any type
-- variables (each one found causes a failure in 'Q').  The resulting
-- monomorphic type is returned together with the table of data types
-- reachable from it, computed by 'exploreType' starting from empty stacks.
exploreRecursiveType :: Type -> Q (MonoType, DataTypes)
exploreRecursiveType tau = do
  mty <- toMonoType rejectTypeVar =<< summariseType tau
  dtypes <- exploreType mempty mempty mty
  return (mty, dtypes)
  where
    rejectTypeVar :: Name -> Q MonoType
    rejectTypeVar n =
      fail $ "Reverse AD input or output type is polymorphic (contains type variable "
             ++ show n ++ ")"
-- | Reverse-mode AD transformation of a typed code fragment.
--
-- This is 'reverseAD'' with the input and output 'Structure's derived from
-- the 'Typeable' instances of @a@ and @b@ via 'structureFromTypeable'.
reverseAD :: forall a b. (Typeable a, Typeable b)
          => Code Q (a -> b)
          -> Code Q (a -> (b, b -> a))
reverseAD = reverseAD' inputStructure outputStructure
  where
    inputStructure = structureFromTypeable (Proxy :: Proxy a)
    outputStructure = structureFromTypeable (Proxy :: Proxy b)
-- | Reverse-mode AD transformation with explicitly supplied 'Structure's
-- for the input and the output type.
--
-- The structure actions are run (input first, then output), the typed code
-- is untyped, and the resulting expression is passed to 'transform'.  The
-- generated expression is coerced back to typed code with
-- 'unsafeCodeCoerce'; its type is not checked here.
reverseAD' :: forall a b.
              Q Structure
           -> Q Structure
           -> Code Q (a -> b)
           -> Code Q (a -> (b, b -> a))
reverseAD' inpStruc outStruc inputCode =
  unsafeCodeCoerce $ do
    inStructure <- inpStruc
    outStructure <- outStruc
    untyped <- unTypeCode inputCode
    transform inStructure outStructure untyped
-- | Differentiate an untyped Template Haskell expression.
--
-- The expression is first converted with 'translate'.  The result has to be
-- a lambda ('ELam'); its pattern and body are handed to 'transform''.  Any
-- other expression form fails in 'Q'.
transform :: Structure -> Structure -> TH.Exp -> Q TH.Exp
transform inpStruc outStruc expr =
  translate expr >>= \case
    ELam pat body -> transform' inpStruc outStruc pat body
    other ->
      fail ("Top-level expression in reverseAD must be lambda, but is: "
            ++ show other)
-- | Generate the differentiated program for a lambda with the given
-- argument pattern and body.
--
-- Both structures are first forced to normal form.  The generated
-- expression is a lambda over the original argument that, inside
-- @runFwdM@, splices in
--
-- 1. 'interleave' of the input structure applied to the argument, bound to
--    the original pattern and a rebuild function,
-- 2. 'ddr' of the body (with the pattern-bound variables as environment),
-- 3. 'deinterleave' of the output structure applied to the result, yielding
--    a primal and a dual part,
--
-- and returns a pair of the primal result and a function from an adjoint
-- that applies (strictly, via @$!@) the rebuild function to the result of
-- @dualpass@.
--
-- Fresh names are requested in a fixed order with fixed name hints.
transform' :: Structure -> Structure -> Pat -> Source.Exp -> Q TH.Exp
transform' inpStruc outStruc pat expr = do
  -- Fully evaluate both structures before any code is generated.
  mapM_ (liftIO . evaluate . force) [inpStruc, outStruc]
  patbound <- boundVars pat
  argN <- newName "arg"
  rebuildN <- newName "rebuild"
  rebuildN' <- newName "rebuild'"
  outN <- newName "out"
  outjdN <- newName "outjd"
  outnextjiN <- newName "outnextji"
  primalN <- newName "primal"
  primalN' <- newName "primal'"
  dualN <- newName "dual"
  dualN' <- newName "dual'"
  adjN <- newName "adjoint"
  [| \ $(varP argN) ->
        let ( $(varP outjdN)
              , $(varP outnextjiN)
              , ($(varP rebuildN), $(varP primalN), $(varP dualN))
              ) =
              runFwdM $ do
                ($(pure pat), $(varP rebuildN')) <- $(interleave inpStruc (VarE argN))
                $(varP outN) <- $(ddr patbound expr)
                ($(varP primalN'), $(varP dualN')) <- mpure $(deinterleave outStruc (VarE outN))
                mpure ($(varE rebuildN'), $(varE primalN'), $(varE dualN'))
        in ( $(varE primalN)
           , \ $(varP adjN) ->
                 $(varE rebuildN) $! dualpass $(varE outjdN) $(varE outnextjiN) $(varE dualN) $(varE adjN)
           ) |]
-- | A record of three strict fields: an unpacked 'Double', an unpacked
-- 'NID' and a 'Contrib'.
--
-- NOTE(review): presumably the dual-number representation of a scalar
-- (primal value, node identifier, backpropagation contributions) -- confirm
-- against the definitions of 'NID' and 'Contrib'.
data DN = DN {-# UNPACK #-} !Double
             {-# UNPACK #-} !NID
             !Contrib
-- | The set of locally bound names during translation. 'ddr' treats an
-- 'EVar' whose name is in this set as an ordinary variable reference, and
-- extends the set with the names bound by lambdas, lets and case alternatives.
type Env = Set Name
-- | Translate a source expression into a Template Haskell expression.
--
-- Judging by the generated code (everything is wrapped in @mpure@ or
-- sequenced in a @do@ block, and 'ESig' annotates with @FwdM@), the result is
-- a computation in the @FwdM@ monad.
--
-- Variables are handled as follows:
--
-- * names in the 'Env' are referenced directly;
-- * a fixed set of numeric primitives is mapped to the @apply*Op@ helpers,
--   together with code computing their partial derivatives;
-- * @($)@ and @(.)@ are expanded to lambdas and translated recursively;
-- * any other free variable is accepted only if 'reifyType' shows that all of
--   its parameter types and its result type satisfy @isDiscrete@.
--
-- Fails in 'Q' on unsupported literals, on a partially applied @(|*|)@ and on
-- unsupported free variables.
ddr :: Env -> Source.Exp -> Q TH.Exp
ddr env = \case
  EVar name
    -- Locally bound: just refer to the variable.
    | name `Set.member` env -> [| mpure $(varE name) |]
    | name == 'fromIntegral -> [| mpure fromIntegralOp |]
    | name == 'negate -> do
        argName <- newName "x"
        [| mpure (\ $(varP argName) -> applyUnaryOp $(varE argName) negate (\_ -> (-1))) |]
    | name == 'sqrt -> unaryPrim 'sqrt (\p -> [| 1 / (2 * sqrt $p) |])
    | name == 'exp -> do
        -- 'exp' goes through applyUnaryOp2, whose derivative lambda takes two
        -- arguments; only the second one is used here.
        argName <- newName "x"
        resultName <- newName "p"
        [| mpure (\ $(varP argName) ->
              applyUnaryOp2 $(varE argName) exp (\_ $(varP resultName) -> $(varE resultName))) |]
    | name == 'log -> unaryPrim 'log (\p -> [| 1 / $p |])
    | name == 'sin -> unaryPrim 'sin (\p -> [| cos $p |])
    | name == 'cos -> unaryPrim 'cos (\p -> [| - sin $p |])
    | name == '($) -> do
        -- ($) becomes \f -> \x -> f x, which is then translated as usual.
        funName <- newName "f"
        argName <- newName "x"
        ddr env $
          ELam (VarP funName) $ ELam (VarP argName) $
            EVar funName `EApp` EVar argName
    | name == '(.) -> do
        -- (.) becomes \f -> \g -> \x -> f (g x).
        outerName <- newName "f"
        innerName <- newName "g"
        argName <- newName "x"
        ddr env $
          ELam (VarP outerName) $ ELam (VarP innerName) $ ELam (VarP argName) $
            EVar outerName `EApp` (EVar innerName `EApp` EVar argName)
    | name == '(+) -> ddrNumBinOp '(+) (False, False) (\_ _ -> [| (1, 1) |])
    | name == '(-) -> ddrNumBinOp '(-) (False, False) (\_ _ -> [| (1, -1) |])
    | name == '(*) -> ddrNumBinOp '(*) (True, True) (\x y -> [| ($y, $x) |])
    | name == '(/) -> ddrNumBinOp '(/) (True, True) (\x y -> [| (recip $y, (- $x) / ($y * $y)) |])
    | name `elem` ['(==), '(/=), '(<), '(>), '(<=), '(>=)] -> do
        lhsName <- newName "x"
        rhsName <- newName "y"
        [| mpure (\ $(varP lhsName) ->
             mpure (\ $(varP rhsName) ->
               mpure (applyCmpOp $(varE lhsName) $(varE rhsName)
                                 $(varE name)))) |]
    | name == '(|*|) ->
        fail "The parallel combinator (|*|) should be applied directly; partially applying it is pointless."
    | name == 'error -> [| mpure error |]
    | name == 'fst -> [| mpure (\x -> mpure (fst x)) |]
    | name == 'snd -> [| mpure (\x -> mpure (snd x)) |]
    | otherwise -> do
        -- Unknown free variable: only allowed if its whole type is discrete.
        (params, retty) <- unpackFunctionType <$> reifyType name
        if all isDiscrete params && isDiscrete retty
          then [| mpure $(liftKleisliN (length params) (VarE name)) |]
          else fail $ concat
                 [ "Most free variables not supported in reverseAD: ", show name
                 , " (env = ", show env, ")" ]

  ECon name -> do
    -- checkDatacon yields the constructor's field types; only their number
    -- is needed here.
    fieldtys <- checkDatacon name
    [| mpure $(liftKleisliN (length fieldtys) (ConE name)) |]

  ELit lit -> case lit of
    RationalL{} -> do
      -- A rational literal becomes a DN carrying a freshly generated ID.
      idName <- newName "i"
      [| do $(varP idName) <- fwdmGenId
            mpure (DN $(litE lit) $(varE idName) C0) |]
    FloatPrimL{} -> fail "float prim?"
    DoublePrimL{} -> fail "double prim?"
    IntegerL{} -> [| mpure $(litE lit) |]
    StringL{} -> [| mpure $(litE lit) |]
    _ -> fail ("literal? " ++ show lit)

  -- A fully applied (|*|) must be matched before the generic application case.
  EVar name `EApp` e1 `EApp` e2 | name == '(|*|) ->
    [| fwdmPar $(ddr env e1) $(ddr env e2) |]

  EApp fun arg -> do
    (wrap, [funName, argName]) <- ddrList env [fun, arg]
    pure (wrap (AppE (VarE funName) (VarE argName)))

  ELam pat body -> do
    patbound <- boundVars pat
    body' <- ddr (env <> patbound) body
    [| mpure (\ $(pure pat) -> $(pure body')) |]

  ETup es -> do
    (wrap, vars) <- ddrList env es
    pure (wrap (AppE (VarE 'mpure) (TupE (map (Just . VarE) vars))))

  ECond cond whenTrue whenFalse -> do
    cond' <- ddr env cond
    condName <- newName "bool"
    whenTrue' <- ddr env whenTrue
    whenFalse' <- ddr env whenFalse
    -- Run the condition first, then branch on its result.
    [| do $(varP condName) <- $(pure cond')
          if $(varE condName) then $(pure whenTrue') else $(pure whenFalse') |]

  ELet decs body -> do
    (wrap, vars) <- ddrDecs env decs
    wrap <$> ddr (env <> Set.fromList vars) body

  ECase scrutinee matches -> do
    (wrap, [scrutName]) <- ddrList env [scrutinee]
    let translateAlt (pat, rhs) = do
          patbound <- boundVars pat
          ddrPat pat
          rhs' <- ddr (env <> patbound) rhs
          pure (Match pat (NormalB rhs') [])
    alts <- traverse translateAlt matches
    pure (wrap (CaseE (VarE scrutName) alts))

  EList es -> do
    (wrap, vars) <- ddrList env es
    pure (wrap (AppE (VarE 'mpure) (ListE (map VarE vars))))

  ESig e ty ->
    SigE <$> ddr env e <*> (AppT (ConT ''FwdM) <$> ddrType ty)
  where
    -- Generate the lifted version of the unary primitive @op@. The second
    -- argument builds the body of the lambda passed to applyUnaryOp from the
    -- variable that lambda binds (presumably the primal argument, given the
    -- derivative formulas used by the callers).
    unaryPrim :: Name -> (Q TH.Exp -> Q TH.Exp) -> Q TH.Exp
    unaryPrim op mkDeriv = do
      argName <- newName "x"
      primalName <- newName "p"
      [| mpure (\ $(varP argName) ->
            applyUnaryOp $(varE argName) $(varE op)
                         (\ $(varP primalName) -> $(mkDeriv (varE primalName)))) |]
-- | Generate the lifted, curried version of a binary numeric operator.
--
-- The generated code is @mpure (\\x -> mpure (\\y -> applyBinaryOp x y op g))@,
-- where @g@ is a two-argument lambda whose body comes from the third
-- argument.
--
-- * First argument: the name of the operator to pass to @applyBinaryOp@.
-- * Second argument: for each of the two parameters of @g@, whether the body
--   refers to it. An unused parameter is bound with a wildcard pattern
--   instead of a variable.
-- * Third argument: builds the body of @g@ from expressions referring to its
--   two parameters. It must not use an expression whose flag is 'False',
--   because that variable is not bound in the generated lambda.
ddrNumBinOp :: Name
            -> (Bool, Bool)
            -> (Q TH.Exp -> Q TH.Exp -> Q TH.Exp)
            -> Q TH.Exp
ddrNumBinOp op (xused, yused) mkgrad = do
  lhsName <- newName "x"
  rhsName <- newName "y"
  lhsPrimal <- newName "px"
  rhsPrimal <- newName "py"
  let primalPat :: Bool -> Name -> Q Pat
      primalPat used nm
        | used = varP nm
        | otherwise = wildP
  [| mpure (\ $(varP lhsName) ->
       mpure (\ $(varP rhsName) ->
         applyBinaryOp
           $(varE lhsName)
           $(varE rhsName)
           $(varE op)
           (\ $(primalPat xused lhsPrimal) $(primalPat yused rhsPrimal) ->
               $(mkgrad (varE lhsPrimal) (varE rhsPrimal))))) |]
-- | Translate a list of expressions with 'ddr', binding each result to a
-- fresh name.
--
-- Returns a pair of:
--
-- * a wrapper that, given a final expression, builds a @do@ block that binds
--   every translated expression (in order) to its name and ends with that
--   final expression;
-- * the fresh names, in the same order as the input expressions.
ddrList :: Env -> [Source.Exp] -> Q (TH.Exp -> TH.Exp, [Name])
ddrList env es = do
  -- Names are generated before any expression is translated.
  names <- traverse (\idx -> newName ('x' : show idx)) [1 .. length es]
  translated <- traverse (ddr env) es
  let bindings = zipWith (BindS . VarP) names translated
      wrap rest = DoE Nothing (bindings ++ [NoBindS rest])
  pure (wrap, names)
-- | Translate a sequence of declaration groups.
--
-- Groups are processed in order; the names bound by a group are added to the
-- environment used for all later groups. Returns a wrapper that puts the
-- generated statements in front of a given body expression inside a @do@
-- block, together with all names the groups declare, in declaration order.
ddrDecs :: Env -> [DecGroup] -> Q (TH.Exp -> TH.Exp, [Name])
ddrDecs env decs = do
  (stmts, declared) <- go env decs
  pure (\body -> DoE Nothing (stmts ++ [NoBindS body]), declared)
  where
    -- Walk the groups left to right, growing the scope as we go.
    go :: Env -> [DecGroup] -> Q ([Stmt], [Name])
    go _ [] = pure ([], [])
    go scope (grp : grps) = do
      (stmt, bound) <- ddrDecGroup scope grp
      (laterStmts, laterBound) <- go (scope <> Set.fromList bound) grps
      pure (stmt : laterStmts, bound ++ laterBound)
-- | Translate one declaration group into a single @do@ statement, and return
-- the names it binds.
--
-- * A 'DecVar' becomes a monadic bind of the translated right-hand side. If
--   the variable has a type signature, the right-hand side is annotated with
--   @FwdM@ applied to the translated type.
-- * A 'DecMutGroup' becomes one @let@ statement holding all functions of the
--   group, each defined as a lambda. All names of the group are in scope in
--   every right-hand side. Signatures, if present, are translated with
--   'ddrType' and kept as signature declarations.
ddrDecGroup :: Env -> DecGroup -> Q (Stmt, [Name])
ddrDecGroup env (DecVar name msig rhs) = do
  rhs' <- ddr env rhs
  annotated <- maybe (pure rhs')
                     (fmap (SigE rhs' . AppT (ConT ''FwdM)) . ddrType)
                     msig
  pure (BindS (VarP name) annotated, [name])
ddrDecGroup env (DecMutGroup lams) = do
  decs <- concat <$> traverse translateLam lams
  pure (LetS decs, groupNames)
  where
    groupNames :: [Name]
    groupNames = [name | (name, _, _, _) <- lams]

    -- Translate one function of the group into its (optional) signature
    -- followed by its definition.
    translateLam :: (Name, Maybe Type, Pat, Source.Exp) -> Q [Dec]
    translateLam (name, msig, pat, rhs) = do
      ddrPat pat
      bound <- boundVars pat
      rhs' <- ddr (env <> Set.fromList groupNames <> bound) rhs
      let definition = ValD (VarP name) (NormalB (LamE [pat] rhs')) []
      case msig of
        Nothing -> pure [definition]
        Just sig -> do
          sig' <- ddrType sig
          pure [SigD name sig', definition]
-- | Check that a pattern only uses supported features; the pattern itself is
-- not changed.
--
-- Accepted: variables, wildcards, tuples, unboxed tuples, lists, parentheses,
-- as-patterns, and constructor patterns (prefix or infix) without type
-- applications whose constructor passes @checkDatacon@. Sub-patterns are
-- checked recursively.
--
-- Everything else (literals, unboxed sums, type applications, unresolved
-- infix patterns, irrefutable, bang, record, signature and view patterns)
-- makes the 'Q' computation fail.
ddrPat :: Pat -> Q ()
ddrPat = \case
  VarP{} -> pure ()
  WildP -> pure ()
  LitP{} -> fail "Literals in patterns unsupported"
  TupP ps -> checkAll ps
  UnboxedTupP ps -> checkAll ps
  ListP ps -> checkAll ps
  ParensP p -> ddrPat p
  AsP _ p -> ddrPat p
  p@(ConP name tyapps args)
    | null tyapps -> checkCon name args
    | otherwise -> unsupported "Type applications in patterns" p
  -- An infix constructor pattern is checked like the prefix form with two
  -- arguments and no type applications.
  InfixP p1 name p2 -> checkCon name [p1, p2]
  p@UnboxedSumP{} -> unsupported "Unboxed sums" p
  p@UInfixP{} -> unsupported "UInfix patterns" p
  p@TildeP{} -> unsupported "Irrefutable patterns" p
  p@BangP{} -> unsupported "Bang patterns" p
  p@RecP{} -> unsupported "Records" p
  p@SigP{} -> unsupported "Type signatures in patterns, because then I need to rewrite types and I'm lazy" p
  p@ViewP{} -> unsupported "View patterns" p
  where
    checkAll :: [Pat] -> Q ()
    checkAll = mapM_ ddrPat

    -- Validate the constructor first, then its argument patterns.
    checkCon :: Name -> [Pat] -> Q ()
    checkCon name args = checkDatacon name >> checkAll args

    unsupported :: String -> Pat -> Q ()
    unsupported what p = notSupported what (Just (show p))
-- | Translate a source type to the corresponding type in the differentiated
-- program.
--
-- * @Double@ becomes 'DN'; @Int@ stays @Int@.
-- * A function type @a -> b@ becomes @a' -> FwdM b'@. An arrow with explicit
--   multiplicity @Many@ is treated as a plain arrow.
-- * Fully applied tuple types and applications of any type constructor keep
--   their head and have their arguments translated.
--
-- Fails in 'Q' if some part of the type matches none of these forms; the
-- error message names both the offending part and the whole type.
ddrType :: Type -> Q Type
ddrType = \ty -> either (giveUp ty) pure (go ty)
  where
    giveUp :: Type -> Type -> Q Type
    giveUp whole bad =
      fail $ "Don't know how to differentiate (" ++ show bad
               ++ "), which is a part of the type: " ++ show whole

    -- 'Left' carries the sub-type that could not be translated.
    go :: Type -> Either Type Type
    go (ConT name)
      | name == ''Double = Right (ConT ''DN)
      | name == ''Int = Right (ConT ''Int)
    go (ArrowT `AppT` argTy `AppT` resTy) =
      (\arg res -> ArrowT `AppT` arg `AppT` (ConT ''FwdM `AppT` res))
        <$> go argTy
        <*> go resTy
    go (MulArrowT `AppT` PromotedT multi `AppT` argTy `AppT` resTy)
      | multi == 'Many = go (ArrowT `AppT` argTy `AppT` resTy)
    -- Anything else (including the guard fall-throughs above) is split into
    -- a head and its arguments.
    go ty =
      case collectApps ty of
        (TupleT n, args) | length args == n -> reapply (TupleT n) args
        (ConT name, args) -> reapply (ConT name) args
        _ -> Left ty

    -- Translate the arguments and apply the unchanged head to them again.
    reapply :: Type -> [Type] -> Either Type Type
    reapply hd args = foldl AppT hd <$> traverse go args
-- | Build the expression that applies the interleaving function for the
-- top-level type of a 'Structure' to the given input expression.
--
-- Every data type in the structure's data-type map gets a helper function
-- with a fresh name (tagged @"inter"@ via @genDataNameTag@), generated by
-- 'interleaveData'. The helpers are bound in one @let@ around the
-- application of the main function (from @interleaveType@) to the input.
interleave :: Quote m => Structure -> TH.Exp -> m TH.Exp
interleave (Structure monotype dtypemap) input = do
  -- Map.keys is ascending, so the resulting list is fit for fromAscList.
  helpernames <- Map.fromAscList <$> traverse freshHelperName (Map.keys dtypemap)
  let buildHelper (key, constrs) = do
        fun <- interleaveData helpernames constrs
        pure (ValD (VarP (helpernames Map.! key)) (NormalB fun) [])
  helperDecs <- traverse buildHelper (Map.assocs dtypemap)
  mainfun <- interleaveType helpernames monotype
  pure (LetE helperDecs (AppE mainfun input))
  where
    freshHelperName key@(n, ts) = (key,) <$> newName (genDataNameTag "inter" n ts)
-- | Generate the interleave helper for a single data type.
--
-- @helpernames@ maps each (type name, type arguments) key to the name of its
-- helper; it is only passed through to 'interleaveType' for the fields.
-- @constrs@ lists the constructors of the data type together with their
-- field types.
--
-- The generated expression has the shape
--
-- > \input -> case input of
-- >   Con inp1 .. inpN -> do
-- >     (inp'1, reb1) <- <interleave field 1> inp1
-- >     ..
-- >     mpure <pair (Con inp'1 .. inp'N) (\arr -> Con (reb1 arr) .. (rebN arr))>
--
-- where the final tuple is assembled with the @pair@ helper defined elsewhere
-- in this module. For a constructor without fields the rebuild lambda binds
-- a wildcard instead of @arr@.
interleaveData :: Quote m => Map (Name, [MonoType]) Name -> [(Name, [MonoType])] -> m TH.Exp
interleaveData helpernames constrs = do
  inputvar <- newName "input"
  -- The variable pools are shared between all constructors, so they need as
  -- many entries as the widest constructor has fields. The list is seeded
  -- with 0 so that a type without any constructors does not make 'maximum'
  -- throw on an empty list.
  let maxn = maximum (0 : map (length . snd) constrs)
  allinpvars <- mapM (\i -> newName ("inp" ++ show i)) [1 .. maxn]
  allpostvars <- mapM (\i -> newName ("inp'" ++ show i)) [1 .. maxn]
  allrebuildvars <- mapM (\i -> newName ("reb" ++ show i)) [1 .. maxn]
  bodies <- sequence
    [ do
        let arity = length fieldtys
            inpvars = take arity allinpvars
            postvars = take arity allpostvars
            rebuildvars = take arity allrebuildvars
        exps <- mapM (interleaveType helpernames) fieldtys
        arrname <- newName "arr"
        return $ DoE Nothing $
          -- One monadic bind per field: run the field's interleaver on the
          -- matched input variable.
          [ BindS (TupP [VarP postvar, VarP rebuildvar])
                  (expr `AppE` VarE inpvar)
          | (inpvar, postvar, rebuildvar, expr)
              <- zip4 inpvars postvars rebuildvars exps ]
          ++
          -- Final statement: the rebuilt constructor paired with a function
          -- that reapplies the constructor to every rebuilder's result.
          [ NoBindS $ VarE 'mpure `AppE`
              pair (foldl AppE (ConE conname) (map VarE postvars))
                   (LamE [if null rebuildvars then WildP else VarP arrname] $
                      foldl AppE (ConE conname)
                        [VarE f `AppE` VarE arrname | f <- rebuildvars]) ]
    | (conname, fieldtys) <- constrs ]
  return $ LamE [VarP inputvar] $ CaseE (VarE inputvar)
    [ Match (ConP conname [] (map VarP inpvars))
            (NormalB body)
            []
    | ((conname, fieldtys), body) <- zip constrs bodies
    , let inpvars = take (length fieldtys) allinpvars ]
-- | Generate the interleaving function for a single monomorphic type.
--
-- * 'DiscreteST': the value is returned unchanged, and the rebuild function
--   ignores its argument and yields that same value.
-- * 'ScalarST': an id is drawn from 'fwdmGenIdInterleave'; the value is
--   wrapped in a 'DN' carrying @NID (JobID 0) id@ and 'C0', and the rebuild
--   function indexes its argument vector at that id.
-- * 'ConST': refers to the helper registered in @helpernames@. A missing
--   entry is an internal error, raised when the resulting name is forced.
interleaveType :: Quote m => Map (Name, [MonoType]) Name -> MonoType -> m TH.Exp
interleaveType helpernames ty = case ty of
  DiscreteST -> do
    x <- newName "inpx"
    [| \ $(varP x) -> mpure ($(varE x), \_ -> $(varE x)) |]

  ScalarST -> do
    x <- newName "inpx"
    ix <- newName "i"
    grads <- newName "arr"
    [| \ $(varP x) -> do
         $(varP ix) <- fwdmGenIdInterleave
         mpure ( DN $(varE x) (NID (JobID 0) $(varE ix)) C0
               , \ $(varP grads) -> $(varE grads) VS.! $(varE ix) ) |]

  ConST tyname argtys ->
    let key = (tyname, argtys)
        missing = error ("Helper name not defined? " ++ show key)
    in pure (VarE (maybe missing id (Map.lookup key helpernames)))
-- | Wrap @outexp@ in the generated deinterleaving code for a 'Structure'.
--
-- A fresh helper name (tagged @"deinter"@ via 'genDataNameTag') is created
-- for every data type in the structure's data type map, each helper body is
-- produced by 'deinterleaveData', and the result is a @let@ binding all
-- helpers around the top-level deinterleaver applied to @outexp@.
deinterleave :: Quote m => Structure -> TH.Exp -> m TH.Exp
deinterleave (Structure monotype dtypemap) outexp = do
  -- 'Map.keys' is ascending, which is what 'Map.fromAscList' requires.
  namePairs <- traverse freshHelperName (Map.keys dtypemap)
  let helpernames = Map.fromAscList namePairs
  helperfuns <- traverse (helperDefinition helpernames) (Map.assocs dtypemap)
  mainfun <- deinterleaveType helpernames monotype
  pure (LetE (map helperDec helperfuns) (AppE mainfun outexp))
  where
    freshHelperName key@(n, ts) =
      (,) key <$> newName (genDataNameTag "deinter" n ts)

    helperDefinition names (key, constrs) =
      (,) (names Map.! key) <$> deinterleaveData names constrs

    helperDec (name, fun) = ValD (VarP name) (NormalB fun) []
-- | Generate the deinterleave helper for a single data type.
--
-- @helpernames@ is only passed through to 'deinterleaveType' for the fields.
-- @constrs@ lists the constructors of the data type together with their
-- field types.
--
-- The generated expression has the shape
--
-- > \out -> case out of
-- >   Con d1 .. dN ->
-- >     let (p1, f1) = <deinterleave field 1> d1
-- >         ..
-- >     in <pair (Con p1 .. pN)
-- >              (\(Con a1 .. aN) -> (\bs -> f1 a1 bs >> .. >> fN aN bs) :: <type>)>
--
-- where the tuple is assembled with the @pair@ helper defined elsewhere in
-- this module, and @<type>@ is the signature built below from 'MulArrowT',
-- 'Many', @BuildState@ and @IO ()@. For a constructor without fields the
-- inner function is @\_ -> return ()@.
deinterleaveData :: Quote m => Map (Name, [MonoType]) Name -> [(Name, [MonoType])] -> m TH.Exp
deinterleaveData helpernames constrs = do
  dualvar <- newName "out"
  -- The variable pools are shared between all constructors, so they need as
  -- many entries as the widest constructor has fields. The list is seeded
  -- with 0 so that a type without any constructors does not make 'maximum'
  -- throw on an empty list.
  let maxn = maximum (0 : map (length . snd) constrs)
  alldvars <- mapM (\i -> newName ("d" ++ show i)) [1 .. maxn]
  allpvars <- mapM (\i -> newName ("p" ++ show i)) [1 .. maxn]
  allfvars <- mapM (\i -> newName ("f" ++ show i)) [1 .. maxn]
  allavars <- mapM (\i -> newName ("a" ++ show i)) [1 .. maxn]
  bsvar <- newName "bs"
  -- Sequence a list of actions that all take the same argument. The empty
  -- case is handled separately, so 'foldr1' only ever sees a non-empty list.
  let composeActions [] = LamE [WildP] (VarE 'return `AppE` TupE [])
      composeActions l =
        LamE [VarP bsvar] $
          foldr1 (\a b -> InfixE (Just a) (VarE '(>>)) (Just b))
                 (map (`AppE` VarE bsvar) l)
  bodies <- sequence
    [ do
        let arity = length fieldtys
            dvars = take arity alldvars
            pvars = take arity allpvars
            fvars = take arity allfvars
            avars = take arity allavars
        exps <- mapM (deinterleaveType helpernames) fieldtys
        return $
          LetE [ ValD (TupP [VarP pvar, VarP fvar])
                      (NormalB (expr `AppE` VarE dvar))
                      []
               | (dvar, pvar, fvar, expr) <- zip4 dvars pvars fvars exps ] $
            pair (foldl AppE (ConE conname) (map VarE pvars))
                 (LamE [ConP conname [] (map VarP avars)] $
                    SigE (composeActions [ VarE fvar `AppE` VarE avar
                                         | (fvar, avar) <- zip fvars avars ])
                         (MulArrowT `AppT` ConT 'Many `AppT` ConT ''BuildState
                            `AppT` (ConT ''IO `AppT` TupleT 0)))
    | (conname, fieldtys) <- constrs ]
  return $ LamE [VarP dualvar] $ CaseE (VarE dualvar)
    [ Match (ConP conname [] (map VarP dvars))
            (NormalB body)
            []
    | ((conname, fieldtys), body) <- zip constrs bodies
    , let dvars = take (length fieldtys) alldvars ]
-- | Generate the deinterleaving function for a single monomorphic type.
--
-- Every generated function returns a pair.
--
-- * 'DiscreteST': the value itself, and a two-argument function that ignores
--   both arguments and does nothing in 'IO'.
-- * 'ScalarST': the 'DN' is taken apart into its three fields; the first is
--   returned, and the second component is 'addContrib' applied to the other
--   two.
-- * 'ConST': refers to the helper registered in @helpernames@. A missing
--   entry is an internal error, raised when the resulting name is forced.
deinterleaveType :: Quote m => Map (Name, [MonoType]) Name -> MonoType -> m TH.Exp
deinterleaveType helpernames ty = case ty of
  DiscreteST -> do
    d <- newName "d"
    [| \ $(varP d) -> ($(varE d), \_ _ -> return () :: IO ()) |]

  ScalarST -> do
    primalVar <- newName "prim"
    idVar <- newName "id"
    cbVar <- newName "cb"
    let dnPattern = ConP 'DN [] (map VarP [primalVar, idVar, cbVar])
        contribute = foldl AppE (VarE 'addContrib) [VarE idVar, VarE cbVar]
    pure (LamE [dnPattern] (pair (VarE primalVar) contribute))

  ConST tyname argtys ->
    let key = (tyname, argtys)
        missing = error ("Helper name not defined? " ++ show key)
    in pure (VarE (maybe missing id (Map.lookup key helpernames)))
-- | Build a descriptive base string for a generated helper name.
--
-- The result is @prefix@, followed by the alphanumeric characters of the
-- shown type name, followed by one @_@-prefixed tag per type argument.
genDataNameTag :: String -> Name -> [MonoType] -> String
genDataNameTag prefix tyname argtys =
  concat (prefix : nameTag tyname : ['_' : typeTag ty | ty <- argtys])
  where
    -- Alphanumeric characters of the shown name; "xx" if none remain.
    nameTag :: Name -> String
    nameTag nm
      | null cleaned = "xx"
      | otherwise = cleaned
      where
        cleaned = filter isAlphaNum (show nm)

    -- Tag for a type: "i" for discrete, "s" for scalar, and the type name
    -- followed directly by its arguments' tags for a constructor type.
    typeTag :: MonoType -> String
    typeTag = \case
      DiscreteST -> "i"
      ScalarST -> "s"
      ConST nm args -> nameTag nm ++ concatMap typeTag args
-- | Numeric types on which the primitive operations can be performed in the
-- 'FwdM' monad.
class NumOperation a where
  -- | The representation of @a@ that the operations below work on. The
  -- family is injective: the representation determines @a@.
  type DualNum a = r | r -> a

  -- | Apply a binary operation. The third argument computes the primal
  -- result from the two primal operands; the fourth returns a pair with one
  -- component per operand (presumably the partial derivatives -- confirm
  -- against the callers).
  applyBinaryOp
    :: DualNum a
    -> DualNum a
    -> (a -> a -> a)
    -> (a -> a -> (a, a))
    -> FwdM (DualNum a)

  -- | Apply a unary operation. The second argument computes the primal
  -- result; the third is the accompanying gradient function of the input.
  applyUnaryOp
    :: DualNum a
    -> (a -> a)
    -> (a -> a)
    -> FwdM (DualNum a)

  -- | Like 'applyUnaryOp', but the gradient function takes two arguments.
  applyUnaryOp2
    :: DualNum a
    -> (a -> a)
    -> (a -> a -> a)
    -> FwdM (DualNum a)

  -- | Compare two values using the given predicate on the underlying type.
  applyCmpOp
    :: DualNum a
    -> DualNum a
    -> (a -> a -> Bool)
    -> Bool

  -- | Convert an integral value into the representation of @a@.
  fromIntegralOp
    :: Integral b
    => b
    -> FwdM (DualNum a)
-- | 'Double' values are represented as 'DN' nodes. Every operation that
-- produces a new value draws a fresh id with 'fwdmGenId'.
instance NumOperation Double where
  type DualNum Double = DN

  -- The new node gets a 'C2' contribution with one edge per operand; each
  -- edge carries the operand's id and contribution plus its component of
  -- the gradient function's result.
  applyBinaryOp (DN lhs lhsId lhsContrib) (DN rhs rhsId rhsContrib) primalFn gradFn = do
    let partials = gradFn lhs rhs
    nodeId <- fwdmGenId
    let lhsEdge = CEdge lhsId lhsContrib (fst partials)
        rhsEdge = CEdge rhsId rhsContrib (snd partials)
    pure $ DN (primalFn lhs rhs) nodeId (C2 lhsEdge rhsEdge)

  -- Single-edge ('C1') contribution weighted by the gradient at the input.
  applyUnaryOp (DN arg argId argContrib) primalFn gradFn = do
    nodeId <- fwdmGenId
    let edge = CEdge argId argContrib (gradFn arg)
    pure $ DN (primalFn arg) nodeId (C1 edge)

  -- Here the gradient function is given both the input and the primal
  -- result, which is computed once and shared.
  applyUnaryOp2 (DN arg argId argContrib) primalFn gradFn = do
    nodeId <- fwdmGenId
    let result = primalFn arg
        edge = CEdge argId argContrib (gradFn arg result)
    pure $ DN result nodeId (C1 edge)

  -- Comparisons only look at the primal values.
  applyCmpOp (DN lhs _ _) (DN rhs _ _) cmp = cmp lhs rhs

  -- A converted integral value becomes a node with a 'C0' contribution.
  fromIntegralOp n = do
    nodeId <- fwdmGenId
    pure $ DN (fromIntegral n) nodeId C0
-- | 'Int' values are represented by themselves: every operation applies
-- only the primal function and ignores the gradient function.
instance NumOperation Int where
  type DualNum Int = Int

  applyBinaryOp lhs rhs primalFn _ = pure $ primalFn lhs rhs

  applyUnaryOp arg primalFn _ = pure $ primalFn arg

  applyUnaryOp2 arg primalFn _ = pure $ primalFn arg

  applyCmpOp lhs rhs cmp = cmp lhs rhs

  fromIntegralOp n = pure $ fromIntegral n
-- | Check a data constructor and return the types of its fields.
--
-- The constructor's type is reified and decomposed with 'fromDataconType'.
-- All type arguments of the result type must be plain type variables;
-- otherwise the computation fails. The type constructor is then applied to
-- @()@ once per type variable and handed to 'exploreRecursiveType', whose
-- result is discarded (presumably this serves as a validity check --
-- confirm against 'exploreRecursiveType').
checkDatacon :: Name -> Q [Type]
checkDatacon name = do
  conty <- reifyType name
  (tycon, tyargs, fieldtys) <-
    maybe
      (fail ("Could not deduce root type from type of data constructor " ++ pprint name))
      pure
      (fromDataconType conty)
  tyvars <-
    maybe
      (fail "Normal constructor has GADT properties?")
      pure
      (traverse tyVarName tyargs)
  -- Only the number of type variables matters: each becomes a unit argument.
  let appliedType = foldl AppT tycon (ConT ''() <$ tyvars)
  _ <- exploreRecursiveType appliedType
  pure fieldtys
  where
    tyVarName :: Type -> Maybe Name
    tyVarName (VarT n) = Just n
    tyVarName _ = Nothing
-- | Split the type of a data constructor into
-- @(result type constructor, its type arguments, field types)@.
--
-- Leading @forall@s are skipped.  Fields are collected from ordinary
-- arrows and from multiplicity arrows whose multiplicity is the promoted
-- 'One'; a multiplicity arrow with any other promoted multiplicity yields
-- 'Nothing'.  The remaining result type is decomposed with
-- 'extractTypeCon', so 'Nothing' is also returned when that fails.
fromDataconType :: Type -> Maybe (Type, [Type], [Type])
fromDataconType = \case
    ForallT _ _ body -> fromDataconType body
    ArrowT `AppT` field `AppT` rest -> field `consOnto` rest
    MulArrowT `AppT` PromotedT mult `AppT` field `AppT` rest
      | mult == 'One -> field `consOnto` rest
      | otherwise -> Nothing
    root -> fmap withoutFields (extractTypeCon root)
  where
    -- Decompose the rest of the type and put this field in front.
    consOnto :: Type -> Type -> Maybe (Type, [Type], [Type])
    consOnto field rest = fmap (addField field) (fromDataconType rest)

    addField :: Type -> (Type, [Type], [Type]) -> (Type, [Type], [Type])
    addField field (con, params, fields) = (con, params, field : fields)

    withoutFields :: (Type, [Type]) -> (Type, [Type], [Type])
    withoutFields (con, params) = (con, params, [])
-- | Decompose an applied type into its head type constructor and the
-- arguments it is applied to, in left-to-right order.
--
-- The head must be a named constructor ('ConT'), the list constructor
-- ('ListT', returned as @ConT ''[]@) or a tuple constructor ('TupleT');
-- any other head yields 'Nothing'.
extractTypeCon :: Type -> Maybe (Type, [Type])
extractTypeCon = go []
  where
    -- Walk down the spine of applications.  The outermost argument is
    -- seen first, so consing onto the accumulator yields the arguments in
    -- source order without repeatedly appending to the end of a list.
    go :: [Type] -> Type -> Maybe (Type, [Type])
    go args (AppT t arg) = go (arg : args) t
    go args (ConT n) = Just (ConT n, args)
    go args ListT = Just (ConT ''[], args)
    go args (TupleT n) = Just (TupleT n, args)
    go _ _ = Nothing
-- | Split a function type into its argument types and its final result
-- type.
--
-- Ordinary arrows are unpacked, as are multiplicity arrows whose
-- multiplicity is the promoted 'Many'.  Anything else (including a
-- multiplicity arrow with a different multiplicity) is treated as the
-- result type, giving an empty argument list at that point.
unpackFunctionType :: Type -> ([Type], Type)
unpackFunctionType = \case
    ArrowT `AppT` arg `AppT` rest -> arg `prependTo` rest
    MulArrowT `AppT` PromotedT mult `AppT` arg `AppT` rest
      | mult == 'Many -> arg `prependTo` rest
    result -> ([], result)
  where
    -- Unpack the remainder and put this argument in front.
    prependTo :: Type -> Type -> ([Type], Type)
    prependTo arg rest =
      let (args, result) = unpackFunctionType rest
      in (arg : args, result)
-- | Decide whether a type counts as discrete.
--
-- A named type is discrete when its name occurs in 'discreteTypeNames'.
-- A fully applied tuple is discrete when all of its components are, and a
-- list applied to exactly one argument is discrete when its element type
-- is.  Every other type is not discrete.
isDiscrete :: Type -> Bool
isDiscrete = \case
    ConT n -> n `elem` discreteTypeNames
    ty@AppT{} ->
      case collectApps ty of
        (TupleT arity, components)
          | length components == arity -> all isDiscrete components
        (ListT, [element]) -> isDiscrete element
        _ -> False
    _ -> False
-- | Flatten a chain of type applications into the head of the chain and
-- the list of arguments it is applied to, in left-to-right order.  A type
-- that is not an application is returned with an empty argument list.
collectApps :: Type -> (Type, [Type])
collectApps = unwind []
  where
    -- The outermost application carries the last argument, so consing
    -- while descending leaves the arguments in source order.
    unwind :: [Type] -> Type -> (Type, [Type])
    unwind args = \case
      AppT fn arg -> unwind (arg : args) fn
      hd -> (hd, args)
-- | Wrap an expression in @n@ nested lambdas, each of whose bodies is
-- passed to @mpure@.
--
-- For @n == 0@ the expression is returned unchanged.  Otherwise a fresh
-- variable is generated, the expression is applied to it, and the
-- remaining @n - 1@ layers are built around that application.
--
-- NOTE(review): a negative @n@ never reaches the base case -- callers are
-- presumably expected to pass a non-negative arity.
liftKleisliN :: Int -> TH.Exp -> Q TH.Exp
liftKleisliN arity fun
  | arity == 0 = return fun
  | otherwise = do
      arg <- newName "x"
      let applied = AppE fun (VarE arg)
      [| \ $(varP arg) -> mpure $(liftKleisliN (arity - 1) applied) |]
-- | Build the Template Haskell expression for the 2-tuple of the two
-- given expressions.
pair :: TH.Exp -> TH.Exp -> TH.Exp
pair lhs rhs = TupE (map Just [lhs, rhs])
-- | The name bound by a type variable binder, whether or not the binder
-- carries a kind annotation.
tvbName :: TyVarBndr () -> Name
tvbName = \case
  PlainTV bound _ -> bound
  KindedTV bound _ _ -> bound
-- | Union a collection of maps, resolving clashes with a key-aware
-- combining function.
--
-- The maps are combined with a right fold starting from the empty map, so
-- for a clashing key the combining function receives the value from the
-- earlier map first and the already-combined value from all later maps
-- second.
mapUnionsWithKey :: (Foldable f, Ord k) => (k -> a -> a -> a) -> f (Map k a) -> Map k a
mapUnionsWithKey combine maps = foldr merge Map.empty maps
  where
    merge earlier later = Map.unionWithKey combine earlier later
-- | Compile-time switch for 'evlog'.  While this is 'False', 'evlog'
-- returns immediately without writing anything.
kENABLE_EVLOG :: Bool
kENABLE_EVLOG = False
-- | Append one line to the event log, prefixed by the current monotonic
-- clock reading in seconds.
--
-- Does nothing unless 'kENABLE_EVLOG' is set.  The write and the flush
-- happen while holding the 'MVar' in 'evlogfile'.
evlog :: String -> IO ()
evlog msg = when kENABLE_EVLOG $ do
  now <- getTime Monotonic
  let seconds :: Double
      seconds = fromIntegral (toNanoSecs now) / 1e9
      line = show seconds ++ " " ++ msg ++ "\n"
  withMVar evlogfile $ \h -> do
    BS.hPut h (BS8.pack line)
    -- Flush after every entry so the file is current after each call.
    hFlush h
-- | Global handle of the event log file, opened in append mode and
-- wrapped in an 'MVar'.
--
-- This is a top-level value created with 'unsafePerformIO'; the NOINLINE
-- pragma keeps it a single shared value, so the file is opened at most
-- once, the first time the value is demanded.
{-# NOINLINE evlogfile #-}
evlogfile :: MVar Handle
evlogfile = unsafePerformIO (openFile "evlogfile.txt" AppendMode >>= newMVar)