{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Data.StreamedPipeline
(
Datastream (..),
DatastreamOptions (..),
datastreamOpts,
streamFrom,
streamFrom',
MonadBase (..),
MonadBaseControl (..),
)
where
import Control.Arrow (second)
import Control.Concurrent.Async.Lifted
import Control.Concurrent.STM hiding (atomically)
import Control.Foldl (FoldM)
import qualified Control.Foldl as L
import Control.Monad
import Control.Monad.Base (MonadBase, liftBase)
import Control.Monad.Cont (ContT (..))
import Control.Monad.Trans.Control
import qualified Data.Vector as V
import Pipes
import Pipes.Concurrent hiding (atomically)
import qualified Pipes.Prelude as P
import Torch.Data.Internal
-- | A dataset whose samples can be streamed, one effectful stream per seed.
--
-- The functional dependency @dataset -> sample@ fixes the sample type once
-- the dataset type is known. What a @seed@ denotes is left to each instance.
class Monad m => Datastream m seed dataset sample | dataset -> sample where
  -- | The stream of samples that @dataset@ produces for the given @seed@.
  streamSamples :: dataset -> seed -> ListT m sample
-- | Options used when streaming from a 'Datastream'.
newtype DatastreamOptions = DatastreamOptions
  { -- | Size handed to 'runWithBuffer' for the output stream built by
    -- 'streamFrom' and 'streamFrom''.
    bufferSize :: Int
  }
-- | Default options, with a buffer size of 4.
--
-- Override individual fields with record update syntax, for example
-- @datastreamOpts { bufferSize = 16 }@.
datastreamOpts :: DatastreamOptions
datastreamOpts = DatastreamOptions 4
-- | Stream samples from @dataset@ for every seed in @seeds@.
--
-- Samples are produced by 'readSamples': one concurrent worker per seed, all
-- writing into the same output. Samples from different seeds may therefore
-- be interleaved in any order. The configured 'bufferSize' is passed
-- straight to 'runWithBuffer'.
streamFrom ::
  forall sample m dataset seed b.
  (Datastream m seed dataset sample, MonadBaseControl IO m, MonadBase IO m) =>
  DatastreamOptions ->
  dataset ->
  ListT m seed ->
  ContT b m (ListT m sample)
streamFrom options dataset seeds =
  runWithBuffer (bufferSize options) (readSamples dataset seeds)
-- | Stream samples from @dataset@, running one worker thread per seed, in a
-- deterministic order.
--
-- Every seed gets its own single-slot mailbox (see 'pairSeedWithBuffer').
-- The consumer visits the mailboxes in a fixed round-robin order. The order
-- of samples in the resulting 'ListT' therefore depends only on what each
-- seed produces, not on thread scheduling. Once a seed is exhausted, it is
-- simply skipped in later passes.
--
-- Each worker seals its own mailbox as soon as it finishes. Previously the
-- mailboxes were sealed only after every worker had finished. A seed that
-- ran out of samples early then left an empty, still-open mailbox on which
-- the round-robin consumer blocked forever, while the remaining workers
-- blocked on their full mailboxes (deadlock).
streamFrom' ::
  forall sample m f dataset seed b.
  (Show sample, Datastream m seed dataset sample, MonadBaseControl IO m, MonadBase IO m, MonadIO m, Foldable f) =>
  DatastreamOptions ->
  dataset ->
  f seed ->
  ContT b m (ListT m sample)
streamFrom' DatastreamOptions {..} dataset seeds = do
  -- Counts how many mailboxes 'fromInputOnce' found exhausted (recv returned
  -- Nothing) during the current round-robin pass.
  workerTracker <- atomically $ newTVar 0
  let consumeSeeds mailboxes o = do
        -- An exhausted mailbox answers Nothing on every pass, so the count
        -- must be per pass. A cumulative count would stop early and drop
        -- samples from seeds that are still producing.
        lift $ atomically $ writeTVar workerTracker 0
        for (each mailboxes) $ \(_, input, _) ->
          fromInputOnce workerTracker input >-> toOutput' o
        exhausted <- lift $ atomically $ readTVar workerTracker
        -- Stop only after a full pass in which every mailbox was exhausted.
        when (exhausted < V.length mailboxes) $ consumeSeeds mailboxes o
      -- One worker per seed. Sealing that seed's mailbox when the worker
      -- finishes lets the consumer see end-of-stream for it immediately.
      runWorker (seed, (output, _, seal)) =
        Concurrently $
          readSamplesDeterministic dataset [(seed, output)]
            `liftedFinally` atomically seal
  runWithBuffer bufferSize $ \o ->
    liftedBracket
      (L.foldM pairSeedWithBuffer seeds)
      (mapM_ (atomically . third . snd))
      ( \pairs ->
          let mailboxes = snd <$> pairs
           in concurrently_
                ( runConcurrently (foldMap runWorker pairs)
                    `liftedFinally` mapM_ (atomically . third) mailboxes
                )
                ( runEffect (consumeSeeds mailboxes o)
                    `liftedFinally` mapM_ (atomically . third) mailboxes
                )
      )
  where
    third (_, _, c) = c
-- | Run one worker per seed, each piping @streamSamples dataset seed@ into
-- @outputBox@, and wait until all workers have finished.
--
-- The whole @seeds@ stream is consumed first, to combine the workers into a
-- single 'Concurrently' action. Only then is that action run.
readSamples ::
  forall m seed dataset sample.
  (Datastream m seed dataset sample, MonadBaseControl IO m) =>
  dataset ->
  ListT m seed ->
  Output sample ->
  m ()
readSamples dataset seeds outputBox = do
  workers <- P.fold addWorker mempty id (enumerate seeds)
  runConcurrently workers
  where
    -- The worker for each newly seen seed is placed in front of those
    -- collected so far.
    addWorker :: Concurrently m () -> seed -> Concurrently m ()
    addWorker others seed = worker seed <> others

    worker :: seed -> Concurrently m ()
    worker seed =
      Concurrently . runEffect $
        enumerate (streamSamples @m @seed @dataset @sample dataset seed)
          >-> toOutput' outputBox
-- | Run one worker per @(seed, outputBox)@ pair, concurrently, and wait until
-- all of them have finished.
--
-- Unlike 'readSamples', every seed writes to its own output box, so the
-- samples of different seeds are never mixed in the same buffer.
readSamplesDeterministic ::
  forall m seed f dataset sample.
  (Datastream m seed dataset sample, MonadBaseControl IO m, MonadIO m, Foldable f) =>
  dataset ->
  f (seed, Output sample) ->
  m ()
readSamplesDeterministic dataset seeds = L.fold collectWorkers seeds
  where
    -- Appends each new worker after those already collected, then runs the
    -- combined 'Concurrently' action.
    collectWorkers :: L.Fold (seed, Output sample) (m ())
    collectWorkers = L.Fold (\acc pair -> acc <> worker pair) mempty runConcurrently

    worker :: (seed, Output sample) -> Concurrently m ()
    worker (seed, outputBox) =
      Concurrently . runEffect $
        enumerate (streamSamples @m @seed @dataset @sample dataset seed)
          >-> toOutput' outputBox
-- | Monadic fold that gives every seed a freshly spawned mailbox.
--
-- The result is a 'V.Vector' of @(seed, (output, input, seal))@ entries in
-- the order the seeds are folded. Each mailbox is created with
-- @'bounded' 1@, so it holds at most one pending sample.
pairSeedWithBuffer :: MonadIO m => FoldM m seed (V.Vector (seed, (Output a, Input a, STM ())))
pairSeedWithBuffer = L.premapM withMailbox (L.generalize L.vector)
  where
    withMailbox seed = do
      mailbox <- liftIO (spawn' (bounded 1))
      pure (seed, mailbox)
-- | Receive at most one value from @input@ and yield it downstream.
--
-- When 'recv' returns 'Nothing' (in pipes-concurrency: the mailbox is sealed
-- and empty), nothing is yielded. Instead, @workerTracker@ is incremented by
-- one in a single STM transaction.
fromInputOnce :: MonadIO m => TVar Int -> Input a -> Producer a m ()
fromInputOnce workerTracker input =
  atomically (recv input)
    >>= maybe (atomically (modifyTVar' workerTracker (+ 1))) yield