{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
module Torch.Data.Pipeline
(
Dataset (..),
DatasetOptions (..),
datasetOpts,
Sample (..),
streamFromMap,
)
where
import Control.Concurrent.Async.Lifted
import Control.Concurrent.STM hiding (atomically)
import Control.Monad
import Control.Monad.Base (MonadBase)
import Control.Monad.Cont (ContT)
import Control.Monad.Trans.Control (MonadBaseControl (..))
import Data.IntMap (IntMap)
import qualified Data.IntMap as I
import Data.Set
import Pipes
import Pipes.Concurrent hiding (atomically)
import System.Random
import Torch.Data.Internal
-- | A keyed (map-style) dataset. Each sample is addressed by a key of type
-- @k@ and is fetched inside the monad @m@. The functional dependencies make
-- the @dataset@ type alone determine @m@, @sample@ and @k@.
class Ord k => Dataset m dataset k sample | dataset -> m, dataset -> sample, dataset -> k where
  -- | Fetch the sample stored under the given key.
  getItem :: dataset -> k -> m sample

  -- | The set of keys available in the dataset.
  keys :: dataset -> Set k
-- | Settings read by 'streamFromMap'.
data DatasetOptions = DatasetOptions
  { dataBufferSize :: Int
  -- ^ Size passed to 'runWithBuffer' for the buffer between the loader and
  -- the consumer of the stream.
  , numWorkers :: Int
  -- ^ Number of concurrent workers that fetch samples.
  , shuffle :: Sample
  -- ^ Whether keys are visited in key-set order or shuffled first.
  }
-- | Default options for a given worker count: the buffer size equals the
-- number of workers and keys are visited sequentially (no shuffling).
datasetOpts :: Int -> DatasetOptions
datasetOpts workerCount =
  DatasetOptions
    { dataBufferSize = workerCount
    , numWorkers = workerCount
    , shuffle = Sequential
    }
-- | The order in which the keys of a dataset are visited.
data Sample where
  -- | Visit keys in the order in which the key set lists them.
  Sequential :: Sample
  -- | Shuffle the keys first, using the wrapped random generator. The
  -- generator type is hidden, so any 'RandomGen' instance can be supplied.
  Shuffle :: forall g. RandomGen g => g -> Sample
-- | Stream every sample of a keyed dataset as a 'ListT'.
--
-- One empty 'TVar' is allocated per key. The resulting @(key, TVar)@ pairs
-- are put in visiting order (untouched for 'Sequential', shuffled with
-- 'fisherYates' for 'Shuffle'), pushed onto an unbounded work queue, and the
-- queue is then sealed. Inside 'runWithBuffer' the workers ('runWorkers')
-- run concurrently with 'awaitNextItem', which walks the same pair list in
-- order and forwards each sample once its 'TVar' has been filled. Samples
-- therefore come out in visiting order regardless of which worker finishes
-- first.
--
-- The second component of the result is the 'Sample' to use for a following
-- pass: 'Sequential' stays 'Sequential', while 'Shuffle' carries the
-- generator returned by the shuffle.
streamFromMap ::
  forall m dataset k sample r.
  (Dataset m dataset k sample, MonadIO m, MonadBaseControl IO m) =>
  DatasetOptions ->
  dataset ->
  ContT r m (ListT m sample, Sample)
streamFromMap DatasetOptions {..} dataset = do
  (keyOutput, keyInput, seal) <- liftIO (spawn' unbounded)

  -- Decide the visiting order and the 'Sample' handed back to the caller.
  let arrange = case shuffle of
        Sequential -> (,Sequential)
        Shuffle gen -> fmap Shuffle . fisherYates gen
  (slots, nextSample) <- arrange <$> liftIO (keyTVarSet (keys dataset))

  -- Enqueue every (key, TVar) pair, then seal the queue so the workers'
  -- input ends once all pairs have been taken.
  keyQueue keyOutput slots
  liftIO (atomically seal)

  stream <-
    runWithBuffer dataBufferSize $ \output ->
      concurrently_
        (runWorkers numWorkers dataset keyInput)
        (awaitNextItem slots output)
  pure (stream, nextSample)
-- | Run @numWorkers@ concurrent copies of a worker loop. Each worker takes
-- @(key, TVar)@ pairs from the shared input, fetches the sample for the key
-- with 'getItem', and stores it as 'Just' in the pair's 'TVar'. A worker
-- finishes when the input produces no more pairs.
runWorkers ::
  (Dataset m dataset k sample, MonadIO m, MonadBaseControl IO m) =>
  Int ->
  dataset ->
  Input (k, TVar (Maybe sample)) ->
  m ()
runWorkers numWorkers dataset keyInput =
  replicateConcurrently_ numWorkers . runEffect $
    for (fromInput' keyInput) fetchInto
  where
    -- Fetch one sample and publish it through its TVar.
    fetchInto (key, slot) =
      lift (getItem dataset key >>= atomically . writeTVar slot . Just)
-- | Walk the @(key, TVar)@ pairs in list order and forward each sample to
-- the output. For every pair the loop blocks (STM 'retry') until the 'TVar'
-- holds a 'Just', takes the sample out, and resets the 'TVar' to 'Nothing'
-- before sending the sample downstream.
awaitNextItem ::
  (MonadBase IO m, MonadIO m) =>
  [(k, TVar (Maybe sample))] ->
  Output sample ->
  m ()
awaitNextItem tvars output =
  runEffect $ for (each tvars) emitWhenReady >-> toOutput' output
  where
    emitWhenReady (_, slot) = atomically (takeSlot slot) >>= yield

    -- Wait for the slot to be filled, then empty it and return its content.
    takeSlot slot =
      readTVar slot
        >>= maybe retry (\item -> writeTVar slot Nothing >> pure item)
-- | Pair every key of the set with a freshly created 'TVar' holding
-- 'Nothing'. All 'TVar's are created in a single STM transaction, and the
-- pairs follow the order of the set's 'toList'.
keyTVarSet :: MonadIO m => Set k -> m [(k, TVar (Maybe sample))]
keyTVarSet keySet = atomically (traverse withEmptySlot (toList keySet))
  where
    withEmptySlot key = do
      slot <- newTVar Nothing
      pure (key, slot)
-- | Send every @(key, TVar)@ pair, in list order, to the given output.
keyQueue ::
  MonadBase IO m =>
  Output (k, TVar (Maybe sample)) ->
  [(k, TVar (Maybe sample))] ->
  m ()
keyQueue queue pairs = runEffect (each pairs >-> toOutput' queue)
-- | One step of the shuffle: add element @x@ as the @i@-th element of the
-- partially shuffled map @m@.
--
-- A position @j@ is drawn from @[0, i]@ with 'randomR'. The element currently
-- at @j@ (if any) is moved to position @i@ and @x@ takes position @j@. When
-- @j == i@ nothing is stored at @j@ yet, so @x@ is simply inserted at @i@.
-- That case is handled with 'I.lookup' rather than 'I.!', so the step is
-- total and does not depend on the map being lazy in its values.
--
-- Returns the updated map together with the advanced generator.
fisherYatesStep :: RandomGen g => (IntMap a, g) -> (Int, a) -> (IntMap a, g)
fisherYatesStep (m, gen) (i, x) = (I.insert j x relocated, gen')
  where
    (j, gen') = randomR (0, i) gen
    -- Move whatever currently sits at j out of the way, to slot i.
    relocated = case I.lookup j m of
      Just displaced -> I.insert i displaced m
      Nothing -> m
-- | Shuffle a list with a Fisher–Yates-style algorithm built on
-- 'fisherYatesStep', returning the shuffled list and the advanced generator.
-- The empty list is returned unchanged with the generator untouched.
--
-- The first element seeds the map at index 0 and the remaining elements are
-- added one by one at indices 1, 2, .... The map is forced after every step
-- so that no chain of unevaluated updates builds up for long inputs; the
-- elements themselves are never forced.
fisherYates :: RandomGen g => g -> [a] -> ([a], g)
fisherYates gen [] = ([], gen)
fisherYates gen (x : xs) = go (I.singleton 0 x, gen) (zip [1 ..] xs)
  where
    go (m, g) [] = (I.elems m, g)
    go acc (entry : rest) =
      let next@(m, _) = fisherYatesStep acc entry
       in m `seq` go next rest