{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Control.Concurrent.ThreadPool (
  -- * Pools
  Pool,
  mkPool,
  mkPoolN,
  globalThreadPool,
  scalePool,

  -- * Running jobs
  submitJob,

  -- * Debug
  debug,
) where

import Control.Concurrent
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Data.IORef
import qualified Data.Vector as V
import Numeric
import System.Clock
import System.IO
import System.IO.Unsafe


-- | Binding to the C symbol @setup_rts_gc_took_tom@. The C side is not part of
-- this module. Within this module the action is run once, from
-- 'globalThreadPool', and only when 'kENABLE_DEBUG' is set.
foreign import ccall "setup_rts_gc_took_tom" c_setup_rts_gc_took_tom :: IO ()


-- | Distance between a clock reading and 'epoch', in milliseconds.
--
-- The distance is computed with 'diffTimeSpec' and converted from nanoseconds
-- by dividing by @1e6@.
ts2float :: TimeSpec -> Double
ts2float t = fromIntegral (toNanoSecs (diffTimeSpec epoch t)) / 1e6

-- | Convert a time span to milliseconds.
--
-- Unlike 'ts2float', the argument is taken as-is (nothing is subtracted from
-- it), so callers pass in an already-computed difference.
tsdiff2float :: TimeSpec -> Double
tsdiff2float t = fromIntegral (toNanoSecs t) / 1e6


-- | Compile-time switch for tracing. When 'False', 'debug' returns
-- immediately and 'globalThreadPool' skips the call to
-- 'c_setup_rts_gc_took_tom'.
kENABLE_DEBUG :: Bool
kENABLE_DEBUG = False

{-# NOINLINE epoch #-}
-- | Reference point for the timestamps printed by 'debug': the reading of the
-- monotonic clock taken when this value is first forced.
--
-- 'unsafePerformIO' turns the one-off clock read into a top-level constant.
-- The @NOINLINE@ pragma keeps the definition from being inlined into its use
-- sites, where the clock read could otherwise be repeated.
epoch :: TimeSpec
epoch = unsafePerformIO (getTime Monotonic)

{-# NOINLINE debugLock #-}
-- | Global lock held by 'debug' while it writes a line to stderr, so that
-- lines from different threads are not interleaved.
--
-- Created with 'unsafePerformIO'. The @NOINLINE@ pragma keeps the definition
-- from being inlined, so that all users share the one 'MVar'.
debugLock :: MVar ()
debugLock = unsafePerformIO (newMVar ())

-- | Write a timestamped trace line to stderr. Does nothing (and does not
-- evaluate the message) when 'kENABLE_DEBUG' is 'False'.
--
-- Each line carries two timestamps in milliseconds relative to 'epoch', shown
-- with four decimals: one taken on entry and one taken after the message has
-- been fully evaluated. They are followed by the calling thread's id in
-- brackets and then the message itself.
debug :: String -> IO ()
debug s
  | not kENABLE_DEBUG = return ()
  | otherwise = do
      t1 <- getTime Monotonic
      -- Fully evaluate the message before taking the lock, so the time spent
      -- building it shows up between the two timestamps and not inside the
      -- critical section.
      _ <- evaluate (force s)
      t2 <- getTime Monotonic
      me <- myThreadId
      -- Serialise output so lines from concurrent threads do not interleave.
      withMVar debugLock $ \() ->
        hPutStrLn stderr $
          "@" ++ showFFloat (Just 4) (ts2float t1) " @"
            ++ showFFloat (Just 4) (ts2float t2) ""
            ++ " [" ++ show me ++ "] " ++ s


-- | A thread pool.
--
-- The three fields are, in order: the channel on which jobs are queued
-- ('submitJob' writes to it, the worker threads read from it), the current
-- set of workers (replaced by 'scalePool'), and the counter from which
-- 'submitJob' draws job ids.
data Pool = Pool (Chan Job) (MVar (V.Vector Worker)) (IORef Int)

-- | Handle to one worker thread: the 'ThreadId' obtained when the worker is
-- forked in 'startWorker'. 'scalePool' uses it to kill the thread.
newtype Worker = Worker ThreadId

-- | A queued unit of work: a job id and the action to run. Both fields are
-- strict. Within this module the id is only used to label 'debug' messages.
data Job = Job !Int !(IO ())

{-# NOINLINE globalThreadPool #-}
-- | A statically allocated thread pool.
--
-- The pool is built with 'mkPool' the first time this value is forced. The
-- @NOINLINE@ pragma keeps the 'unsafePerformIO' call from being inlined, so
-- that only one pool is created. When 'kENABLE_DEBUG' is on,
-- 'c_setup_rts_gc_took_tom' is run first.
globalThreadPool :: Pool
globalThreadPool = unsafePerformIO $ do
  when kENABLE_DEBUG c_setup_rts_gc_took_tom
  mkPool

-- | Create a new thread pool with one worker for every capability (see
-- 'getNumCapabilities'). The capability count is read once, at creation time.
mkPool :: IO Pool
mkPool = getNumCapabilities >>= mkPoolN

-- | Create a new thread pool with the given number of worker threads.
--
-- Workers are numbered from 0 to @n - 1@, and each number is passed to
-- 'startWorker' as the argument for 'forkOn'. All workers read from a single
-- freshly created job channel, and job ids start at 0.
mkPoolN :: Int -> IO Pool
mkPoolN n = do
  chan <- newChan
  workers <- forM [0 .. n - 1] (startWorker chan)
  listref <- newMVar (V.fromListN n workers)
  jidref <- newIORef 0
  return (Pool chan listref jidref)

-- | Fork a worker thread that repeatedly takes a job from the channel and
-- runs it.
--
-- The 'Int' is handed to 'forkOn' to choose the capability the thread is
-- bound to. The returned 'Worker' wraps the new thread's 'ThreadId'.
--
-- NOTE: the job action is run without an exception handler. If it throws,
-- the exception ends the loop and with it the worker thread.
startWorker :: Chan Job -> Int -> IO Worker
startWorker chan i = Worker <$> forkOn i loop
  where
    loop :: IO ()
    loop = do
      debug "waiting"
      -- When the pool is dropped, the channel is dropped with it, so this
      -- readChan blocks indefinitely. The runtime then raises
      -- BlockedIndefinitelyOnMVar in this thread; we turn that into Nothing,
      -- which lets the loop (and so the worker) finish.
      mjob <- catch (Just <$> readChan chan) $
        \(_ :: BlockedIndefinitelyOnMVar) -> return Nothing
      forM_ @Maybe mjob $ \(Job jobid work) -> do
        debug $ "[" ++ show jobid ++ "] << popped job"
        t1 <- getTime Monotonic
        work
        t2 <- getTime Monotonic
        debug $
          "[" ++ show jobid ++ "] job took "
            ++ show (tsdiff2float (diffTimeSpec t1 t2))
            ++ " ms"
        -- Only loop again after actually running a job; on Nothing we fall
        -- through and the thread ends.
        loop

-- | Resize the pool to @target@ workers.
--
-- When the target size is smaller than the current size, this mercilessly
-- and immediately kills the surplus workers (those past the first @target@).
-- If you have jobs running, they may well be cancelled. When the target is
-- larger, new workers are started on the pool's existing job channel and
-- numbered onwards from the current size.
--
-- The worker list is updated under its 'MVar', so concurrent calls are
-- serialised.
scalePool :: Pool -> Int -> IO ()
scalePool (Pool chan listref _) target =
  modifyMVar_ listref $ \workers -> do
    -- At most one of 'tokill' and 'news' is non-empty: shrinking leaves the
    -- range below empty, and growing leaves nothing past the split point.
    let (remain, tokill) = splitAt target (V.toList workers)
    forM_ tokill $ \(Worker tid) -> killThread tid
    news <- forM [V.length workers .. target - 1] (startWorker chan)
    return (V.fromList (remain ++ news))

-- | Submit a job to a thread pool.
--
-- The job is tagged with the next id from the pool's counter (taken and
-- incremented atomically) and appended to the pool's job channel. This
-- returns once the job is queued and does not wait for it to run.
submitJob :: Pool -> IO () -> IO ()
submitJob (Pool chan _ jidref) work = do
  jobid <- atomicModifyIORef' jidref (\i -> (i + 1, i))
  debug $ "[" ++ show jobid ++ "] >> submitJob"
  writeChan chan (Job jobid work)