{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
module Language.Haskell.ReverseAD.TH.Translate where

import Control.Monad (zipWithM, (>=>))
import Data.Either (partitionEithers)
import Data.Foldable (toList)
import Data.Graph
import qualified Data.Map.Strict as Map
import Data.Maybe (catMaybes)
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH as TH
import Language.Haskell.ReverseAD.TH.Source as Source
import Data.Foldable (fold)


-- | Translate a Template Haskell expression into the reverse-AD source
-- language ('Source.Exp').
--
-- Purely syntactic sugar is eliminated on the way:
--
-- * parentheses, infix applications and operator sections become plain
--   applications;
-- * tuple sections and right operator sections become lambdas over fresh
--   names;
-- * multi-argument lambdas become nested single-argument lambdas;
-- * @\\case@ becomes a lambda around a @case@.
--
-- Declaration lists (@let@ blocks and @where@ blocks of case alternatives)
-- are handed to 'transDecs'. Constructs that cannot be represented are
-- rejected in 'Q' via 'fail' or 'notSupported'.
translate :: TH.Exp -> Q Source.Exp
translate = \case
  VarE name -> return $ EVar name
  ConE name -> return $ ECon name
  LitE lit -> return $ ELit lit

  -- Handle ($) specially in case the program needs the special type inference
  -- (let's hope it does not)
  VarE dollar `AppE` e1 `AppE` e2
    | dollar == '($) -> EApp <$> translate e1 <*> translate e2

  AppE e1 e2 -> EApp <$> translate e1 <*> translate e2

  -- Infix application: (e1 `op` e2) is (op e1 e2).
  InfixE (Just e1) e (Just e2) ->
    translate (e `AppE` e1 `AppE` e2)

  -- Left section: (e1 `op`) is (op e1).
  InfixE (Just e1) e Nothing ->
    translate (e `AppE` e1)

  -- Right section: (`op` e2) is (\x -> op x e2) for a fresh x. As with the
  -- tuple sections below, the captured operand ends up under the lambda.
  InfixE Nothing e (Just e2) -> do
    name <- newName "secarg"
    ELam (VarP name) <$> translate (e `AppE` VarE name `AppE` e2)

  -- An operator with neither operand supplied is just the operator.
  InfixE Nothing e Nothing ->
    translate e

  ParensE e -> translate e

  -- Multi-argument lambdas are curried into nested single-argument ones.
  LamE [] e -> translate e
  LamE (pat : pats) e -> ELam pat <$> translate (LamE pats e)

  -- Tuple section: every missing component becomes a lambda parameter,
  -- e.g. (, e) becomes (\x1 -> (x1, e)).
  TupE mes -> do
    -- compute argument and body expression for this tuple item
    let processArg i Nothing = do
          name <- newName ("x" ++ show (i :: Int))
          return (Just (VarP name), EVar name)
        processArg _ (Just e) = do
          e' <- translate e
          return (Nothing, e')

    (args, es') <- unzip <$> zipWithM processArg [1..] mes
    return $ foldr ELam (ETup es') (catMaybes args)

  CondE e1 e2 e3 -> ECond <$> translate e1 <*> translate e2 <*> translate e3

  LetE decs body -> elet <$> transDecs decs <*> translate body

  CaseE expr matches ->
    ECase <$> translate expr <*> traverse translateMatch matches
    where
      -- A case alternative's where-block becomes a let around its right-hand
      -- side; guarded alternatives are rejected.
      translateMatch = \case
        Match pat (NormalB rhs) wdecs -> do
          decs' <- transDecs wdecs
          rhs' <- translate rhs
          return (pat, elet decs' rhs')
        Match _ GuardedB{} _ ->
          notSupported "Guards" (Just (show (CaseE expr matches)))

  ListE es -> EList <$> traverse translate es

  SigE e ty -> ESig <$> translate e <*> return ty

  UnboundVarE n -> fail $ "Free variable in reverseAD: " ++ show n

  -- \case alts  ==>  \lcarg -> case lcarg of alts
  LamCaseE mats -> do
    name <- newName "lcarg"
    ELam (VarP name) <$> translate (CaseE (VarE name) mats)

  -- Unsupported constructs
  e@AppTypeE{} -> notSupported "Type applications" (Just (show e))
  e@UInfixE{} -> notSupported "UInfixE" (Just (show e))
  e@UnboxedTupE{} -> notSupported "Unboxed tuples" (Just (show e))
  e@UnboxedSumE{} -> notSupported "Unboxed sums" (Just (show e))
  e@MultiIfE{} -> notSupported "Multi-way ifs" (Just (show e))
  e@DoE{} -> notSupported "Do blocks" (Just (show e))
  e@MDoE{} -> notSupported "MDo blocks" (Just (show e))
  e@CompE{} -> notSupported "List comprehensions" (Just (show e))
  e@ArithSeqE{} -> notSupported "Arithmetic sequences" (Just (show e))
  e@RecConE{} -> notSupported "Records" (Just (show e))
  e@RecUpdE{} -> notSupported "Records" (Just (show e))
  e@StaticE{} -> notSupported "Cloud Haskell" (Just (show e))
  e@LabelE{} -> notSupported "Overloaded labels" (Just (show e))
  e@ImplicitParamVarE{} -> notSupported "Implicit parameters" (Just (show e))
  e@GetFieldE{} -> notSupported "Records" (Just (show e))
  e@ProjectionE{} -> notSupported "Records" (Just (show e))
  e@LamCasesE{} -> notSupported "Lambda cases" (Just (show e))

-- | A declaration in the restricted form produced by 'desugarDec'.
--
-- It is parameterised over the expression type @e@. 'desugarDec' produces
-- @DesugaredDec TH.Exp@, while 'groupDecs' consumes
-- @DesugaredDec Source.Exp@. The derived 'Functor' / 'Traversable' instances
-- allow the right-hand sides to be mapped from one to the other without
-- touching the names or signatures.
data DesugaredDec e
  = DDVar Name e     -- ^ @x = e@: a simple, non-function variable binding
  | DDSig Name Type  -- ^ @x :: t@: a type signature for the variable @x@
  deriving (Show, Functor, Foldable, Traversable)

-- | Desugar a single declaration into the simple forms understood by the
-- rest of the translation.
--
-- Function declarations are converted to simple variable declarations:
--
-- >   f a b c = E
-- >   f d e g = F
--
-- becomes
--
-- >   f = \arg1 arg2 arg3 -> case (arg1, arg2, arg3) of
-- >                            (a, b, c) -> E
-- >                            (d, e, g) -> F
--
-- Furthermore, pattern bindings are converted to go via a tuple:
--
-- >   (a, Right (b, c)) = E
--
-- becomes
--
-- >   vartup = case E of (a, Right (b, c)) -> (a, b, c)
-- >   a = case vartup of (x, _, _) -> x
-- >   b = case vartup of (_, x, _) -> x
-- >   c = case vartup of (_, _, x) -> x
--
-- The tuple components follow the order of the 'Set' returned by
-- 'boundVars', which is not necessarily source order.
--
-- @where@ blocks are turned into a @let@ around the right-hand side (via
-- 'letE''). SigD, i.e. type signatures, are passed through unchanged as
-- 'DDSig'. Guarded right-hand sides and every other kind of declaration are
-- rejected with 'fail'.
desugarDec :: (Quote m, MonadFail m) => Dec -> m [DesugaredDec TH.Exp]
desugarDec = \case
  ValD (VarP var) (NormalB rhs) wdecs ->
    return [DDVar var (letE' wdecs rhs)]

  ValD pat (NormalB rhs) wdecs -> do
    tupname <- newName "vartup"
    xname <- newName "x"
    vars <- toList <$> boundVars pat
    let nvars = length vars
        -- Tuple pattern that binds only component i (0-based) to xname.
        project i =
          TupP (replicate i WildP ++ [VarP xname] ++ replicate (nvars - 1 - i) WildP)
    return $
      DDVar tupname
            (CaseE (letE' wdecs rhs)
               [Match pat (NormalB (TupE (map (Just . VarE) vars))) []])
      : [ DDVar var
            (CaseE (VarE tupname)
               [Match (project i) (NormalB (VarE xname)) []])
        | (i, var) <- zip [0 ..] vars ]

  FunD name clauses ->
    case traverse fromSimpleClause clauses of
      Left err -> fail err
      -- No clauses at all: traverse succeeds trivially with an empty list.
      Right [] -> fail "Function declaration with empty list of clauses?"
      Right cpairs@((firstPats, _) : _)
        | allEqual [length pats | (pats, _) <- cpairs] -> do
            -- All clauses share one arity, so the first clause is
            -- representative (total, unlike taking 'head' of the clauses).
            let nargs = length firstPats
            argnames <- mapM (\i -> newName ("arg" ++ show i)) [1 .. nargs]
            let body = LamE (map VarP argnames) $
                         CaseE (TupE (map (Just . VarE) argnames))
                           [ Match (TupP ps) (NormalB rhs) []
                           | (ps, rhs) <- cpairs ]
            return [DDVar name body]
        | otherwise ->
            fail $ "Clauses of declaration of " ++ show name ++
                   " do not all have the same number of arguments"
    where
      -- Accept only unguarded clauses; a clause's where-block is folded into
      -- a let around its body.
      fromSimpleClause :: Clause -> Either String ([TH.Pat], TH.Exp)
      fromSimpleClause (Clause pats (NormalB body) []) = Right (pats, body)
      fromSimpleClause (Clause pats (NormalB body) wdecs) =
        fromSimpleClause (Clause pats (NormalB (letE' wdecs body)) [])
      fromSimpleClause (Clause _ GuardedB{} _) =
        Left $ "Guards not supported in declaration of " ++ show name

  SigD name typ -> return [DDSig name typ]

  dec -> fail $ "Only simple declarations supported in reverseAD: " ++ show dec
  where
    -- True when every element of the list equals the first (vacuously True
    -- for the empty list).
    allEqual :: Eq a => [a] -> Bool
    allEqual [] = True
    allEqual (x:xs) = all (== x) xs

-- | Assumes the declarations occur in a let block. Checks that the
-- non-function bindings are non-recursive.
-- Returns a wrapper that defines all of the names, the list of defined names,
-- and the set of all free variables of the collective let-block.
--
-- Splits the desugared declarations into value bindings and type
-- signatures, then orders the bindings into strongly connected components
-- of their references to each other (only names defined in this block are
-- considered). Each component becomes one 'DecGroup':
--
--   * a non-recursive binding becomes a 'DecVar', carrying its signature if
--     one was given;
--   * a recursive component becomes a 'DecMutGroup', which requires every
--     member's right-hand side to be an 'ELam'; any other recursive
--     component is rejected via 'notSupported'.
groupDecs :: [DesugaredDec Source.Exp] -> Q [DecGroup]
groupDecs decs = do
  let -- Value bindings go Left, type signatures go Right.
      classify (DDVar name e) = Left (name, e)
      classify (DDSig name ty) = Right (name, ty)
      (bindings, signatures) = partitionEithers (map classify decs)

      -- Only references to names defined in this block create edges.
      localNames = Set.fromList (map fst bindings)

      sigMap = Map.fromList signatures
      sigOf name = Map.lookup name sigMap

      -- Members of a recursive component must be lambdas.
      asLambda (name, ELam p body) = Just (name, sigOf name, p, body)
      asLambda _ = Nothing

      toGroup :: SCC (Name, Source.Exp) -> Q DecGroup
      toGroup (AcyclicSCC (name, e)) = return (DecVar name (sigOf name) e)
      toGroup (CyclicSCC members) =
        case traverse asLambda members of
          Just lams -> return (DecMutGroup lams)
          Nothing ->
            notSupported "Recursive non-function bindings" (Just (show members))

  -- One graph node per binding: (payload, key, keys it refers to).
  nodes <- traverse
    (\(name, e) -> do
       refs <- freeVars e
       return ((name, e), name, toList (refs `Set.intersection` localNames)))
    bindings
  traverse toGroup (stronglyConnComp nodes)

-- | Translates Template Haskell declarations into source-language
-- declaration groups.
--
-- The pipeline has three steps:
--
--   1. desugar every declaration with 'desugarDec';
--   2. translate each resulting expression with 'translate';
--   3. arrange the results into groups with 'groupDecs'.
transDecs :: [Dec] -> Q [DecGroup]
transDecs decs = do
  desugared <- concatMapM desugarDec decs
  translated <- mapM (traverse translate) desugared
  groupDecs translated

-- | Computes the set of variables occurring free in a 'Source.Exp'.
--
-- Scoping rules:
--
--   * Lambda and case-alternative patterns bind their variables (as
--     reported by 'boundVars') over the corresponding body.
--   * Constructors and literals contribute nothing; type signatures are
--     ignored.
--   * For 'ELet', the declaration groups are processed in order. The names
--     bound by a group scope over the right-hand sides of all later groups
--     and over the let body.
--   * The names of a 'DecMutGroup' also scope over that group's own
--     right-hand sides, since such a group is recursive.
--   * A 'DecVar' is treated as non-recursive: its own name is not removed
--     from its right-hand side. This is conservative — it can only
--     over-report free variables, never miss one.
--
-- The result is monadic only because 'boundVars' can fail on unsupported
-- patterns.
freeVars :: Source.Exp -> Q (Set Name)
freeVars = \case
    EVar n -> return (Set.singleton n)
    ECon{} -> return mempty
    ELit{} -> return mempty
    EApp e1 e2 -> (<>) <$> freeVars e1 <*> freeVars e2
    ELam pat e -> underPattern pat e
    ETup es -> concatMapM freeVars es
    ECond e1 e2 e3 -> concatMapM freeVars [e1, e2, e3]
    ELet decgs body -> goLet mempty decgs
      where
        -- 'bound' holds every name introduced by the groups already
        -- processed. Previously the right-hand sides were not filtered at
        -- all, so let-bound names leaked out as "free" variables.
        goLet :: Set Name -> [DecGroup] -> Q (Set Name)
        goLet bound [] = (Set.\\ bound) <$> freeVars body
        goLet bound (DecVar n _ e : rest) = do
          frees <- freeVars e
          restFrees <- goLet (Set.insert n bound) rest
          return ((frees Set.\\ bound) <> restFrees)
        goLet bound (DecMutGroup ds : rest) = do
          -- The group's own names are in scope in its own bodies.
          let bound' = bound <> Set.fromList [n | (n, _, _, _) <- ds]
          frees <- concatMapM (\(_, _, p, e) -> underPattern p e) ds
          restFrees <- goLet bound' rest
          return ((frees Set.\\ bound') <> restFrees)
    ECase e ms -> (<>) <$> freeVars e <*> concatMapM (uncurry underPattern) ms
    EList es -> concatMapM freeVars es
    ESig e _ -> freeVars e
  where
    -- Free variables of an expression under the binders of a pattern.
    underPattern :: Pat -> Source.Exp -> Q (Set Name)
    underPattern pat e = do
      bound <- boundVars pat
      frees <- freeVars e
      return (frees Set.\\ bound)

-- | Collects the variables bound by a Template Haskell pattern.
--
-- What each pattern form contributes:
--
--   * variable patterns contribute their name;
--   * as-patterns contribute their name plus the variables of the inner
--     pattern;
--   * literal and wildcard patterns contribute nothing;
--   * the remaining supported forms union the variables of their
--     subpatterns.
--
-- Unboxed sums, unresolved infix patterns, record patterns and view
-- patterns are rejected through 'notSupported'.
boundVars :: MonadFail m => Pat -> m (Set Name)
boundVars = \case
    VarP n -> return (Set.singleton n)
    AsP n p -> Set.insert n <$> boundVars p
    LitP _ -> return mempty
    WildP -> return mempty
    TupP ps -> ofAll ps
    UnboxedTupP ps -> ofAll ps
    ListP ps -> ofAll ps
    ConP _ _ ps -> ofAll ps
    InfixP p1 _ p2 -> ofAll [p1, p2]
    ParensP p -> boundVars p
    TildeP p -> boundVars p
    BangP p -> boundVars p
    SigP p _ -> boundVars p
    p@UnboxedSumP{} -> notSupported "Unboxed sums" (Just (show p))
    p@UInfixP{} -> notSupported "UInfixP" (Just (show p))
    p@RecP{} -> notSupported "Records" (Just (show p))
    p@ViewP{} -> notSupported "View patterns" (Just (show p))
  where
    -- Union of the variables bound by several subpatterns, left to right.
    ofAll ps = concatMapM boundVars ps

-- | Fails in any 'MonadFail' with an "unsupported feature" message.
--
-- The message is @descr ++ " not supported in reverseAD"@. When a rendering
-- of the offending item is supplied, it is appended after @": "@.
notSupported :: MonadFail m => String -> Maybe String -> m a
notSupported descr mthing = fail (descr ++ " not supported in reverseAD" ++ detail)
  where
    detail = case mthing of
      Nothing -> ""
      Just thing -> ": " ++ thing

-- | Wraps an expression in a Template Haskell 'LetE' with the given
-- declarations. When there are no declarations, the expression is returned
-- as-is, so no empty @let@ block is produced.
letE' :: [Dec] -> TH.Exp -> TH.Exp
letE' ds rhs
  | null ds = rhs
  | otherwise = LetE ds rhs

-- | Source-language counterpart of 'letE'': wraps an expression in an
-- 'ELet' with the given declaration groups. When there are no groups, the
-- expression is returned unchanged rather than wrapped in an empty let.
elet :: [DecGroup] -> Source.Exp -> Source.Exp
elet ds rhs
  | null ds = rhs
  | otherwise = ELet ds rhs

-- | Runs an effectful, monoid-producing action on every element of a
-- traversable structure and combines the results with 'fold'.
--
-- Effects run in traversal order, and results are combined in that order.
-- An empty structure yields @pure mempty@.
concatMapM :: (Applicative f, Monoid m, Traversable t) => (a -> f m) -> t a -> f m
concatMapM f xs = fold <$> traverse f xs