{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Data.StreamedPipeline
  ( -- * Defining a Datastream
    -- $dataset

    -- * Datastream
    Datastream (..),
    DatastreamOptions (..),
    datastreamOpts,

    -- * Dataloading
    streamFrom,
    streamFrom',

    -- * Reexports
    MonadBase (..),
    MonadBaseControl (..),
  )
where

import Control.Arrow (second)
import Control.Concurrent.Async.Lifted
import Control.Concurrent.STM hiding (atomically)
import Control.Foldl (FoldM)
import qualified Control.Foldl as L
import Control.Monad
import Control.Monad.Base (MonadBase, liftBase)
import Control.Monad.Cont (ContT (..))
import Control.Monad.Trans.Control
import qualified Data.Vector as V
import Pipes
import Pipes.Concurrent hiding (atomically)
import qualified Pipes.Prelude as P
import Torch.Data.Internal

-- $dataset
-- We will show how to retrieve the IMDB dataset as an example datastream.
-- The dataset used here can be found at https://ai.stanford.edu/~amaas/data/sentiment/
--
-- > import Pipes
-- > import qualified Pipes.Safe as Safe
-- > import qualified Pipes.Prelude as P
-- > import System.Directory
-- >
-- > newtype Imdb = Imdb { dataDir :: String }
-- >
-- > data Sentiment = Positive | Negative
-- >
-- > instance (MonadBaseControl IO m, MonadSafe m) => Datastream m Sentiment Imdb (Text, Sentiment) where
-- >   streamSamples Imdb{..} sent = Select $ do
-- >     rawFilePaths <- zip (repeat sent) <$> (liftIO $ listDirectory (dataDir </> sentToPath sent))
-- >     let filePaths = fmap (second $ mappend (dataDir </> sentToPath sent)) rawFilePaths
-- >     for (each filePaths) $ \(rev, fp) -> Safe.withFile fp ReadMode $ \fh ->
-- >       P.zip (PT.fromHandleLn fh) (yield rev)
-- >         where sentToPath Positive = "pos" ++ pure pathSeparator
-- >               sentToPath Negative = "neg" ++ pure pathSeparator
--
-- This streams in movie reviews from each file in either the positive review directory or
-- the negative review directory, depending on the seed value used.
--
-- This highlights a use of seed values that is more interesting than just specifying the thread count, but it also has drawbacks.
-- When running this datastream with either 'streamFrom' or 'streamFrom\'', you need to supply both 'Positive' and 'Negative' as seeds
-- to retrieve the entire IMDB dataset; positive and negative reviews are then streamed in concurrently.
-- The drawback of designing a datastream this way is that concurrency is capped at the number of distinct seeds (two threads here):
-- supplying more seeds would only duplicate data. Ultimately, seeds are flexible and let you design the concurrency however you see fit,
-- but be careful not to use duplicate seed values unless you want duplicate data.

-- | The base datastream class. A dataset returns a stream of samples
-- based on a seed value.
--
-- The functional dependency @dataset -> sample@ means a dataset type fixes its
-- sample type, while the seed type and the monad are not determined by the
-- dataset (one dataset may have instances for several seed types).
-- 'streamFrom' and 'streamFrom\'' call 'streamSamples' once per seed and run
-- the resulting streams concurrently.
class Monad m => Datastream m seed dataset sample | dataset -> sample where
  -- | Produce the stream of samples associated with a single seed.
  streamSamples :: dataset -> seed -> ListT m sample

-- | Datastream options used when loading datastreams. Currently only the buffer size is configurable,
-- since the thread count is controlled by the number of seeds (see the @'streamFrom'@ functions).
newtype DatastreamOptions = DatastreamOptions
  { -- | Max number of samples stored in each buffer at a given time.
    bufferSize :: Int
  }

-- | Default datastream options. Override individual fields with record update
-- syntax, e.g. @datastreamOpts { bufferSize = 16 }@.
datastreamOpts :: DatastreamOptions
datastreamOpts = DatastreamOptions {bufferSize = 4} -- 4 is relatively arbitrary

-- | Return a stream of samples from the given dataset as a continuation.
-- A stream of samples is generated for every seed in the given stream of seeds, and all of these streams are merged
-- into the output stream in a non-deterministic order (if you need determinism see 'streamFrom\'').
-- Every stream created for each seed value is made in its own thread.
--
-- All per-seed streams write into one shared output; 'bufferSize' is forwarded to
-- @runWithBuffer@ (presumably the capacity of that output buffer — see its definition).
streamFrom ::
  forall sample m dataset seed b.
  (Datastream m seed dataset sample, MonadBaseControl IO m, MonadBase IO m) =>
  DatastreamOptions ->
  dataset ->
  ListT m seed ->
  ContT b m (ListT m sample)
streamFrom DatastreamOptions {..} dataset seeds =
  runWithBuffer bufferSize $ readSamples dataset seeds

-- | This function is the same as 'streamFrom' except the seeds are specified as
-- a 'Foldable', and the stream returned has a deterministic ordering. The results
-- from each given seed are interspersed in the order defined by the @'Foldable'@ of seeds.
--
-- How it works: every seed is paired with its own single-slot mailbox. All seeds
-- are streamed concurrently into their mailboxes, while a consumer visits the
-- mailboxes round-robin (in seed order) and forwards at most one sample per
-- mailbox per pass into the output buffer. Mailboxes are sealed once the
-- producers finish, when the consumer finishes, or when the bracket is released.
-- The consumer stops after a pass that leaves the sealed-and-empty counter at or
-- above the number of seeds.
--
-- NOTE(review): mailboxes are sealed only after /all/ seed streams have finished.
-- If one seed's stream ends before the others, the consumer may therefore block
-- on that seed's empty, unsealed mailbox while the remaining producers block on
-- their full ones. Verify the behaviour with seeds that yield different numbers
-- of samples.
streamFrom' ::
  forall sample m f dataset seed b.
  (Datastream m seed dataset sample, MonadBaseControl IO m, MonadBase IO m, MonadIO m, Foldable f) =>
  DatastreamOptions ->
  dataset ->
  f seed ->
  ContT b m (ListT m sample)
streamFrom' DatastreamOptions {..} dataset seeds = do
  -- Incremented each time the consumer finds a mailbox sealed and empty.
  workerTracker <- atomically $ newTVar 0
  -- One round-robin pass over all mailboxes, repeated until the tracker reaches
  -- the number of mailboxes.
  let consumeSeeds :: V.Vector (Output sample, Input sample, STM ()) -> Output sample -> Effect m ()
      consumeSeeds mailboxes o = do
        for (each mailboxes) $ \(_, input, _) ->
          fromInputOnce workerTracker input >-> toOutput' o
        keepReading <- lift $ atomically $ (< V.length mailboxes) <$> readTVar workerTracker
        when keepReading $ consumeSeeds mailboxes o
  runWithBuffer bufferSize $ \o ->
    liftedBracket
      (L.foldM pairSeedWithBuffer seeds)
      -- Release: seal every mailbox so no thread stays blocked on one.
      (mapM_ (atomically . third . snd))
      ( \a ->
          let mailboxes = snd <$> a
              seedAndOutput = second fst3 <$> a
              sealAll = mapM_ (atomically . third) mailboxes
           in concurrently_
                (readSamplesDeterministic dataset seedAndOutput `liftedFinally` sealAll)
                (runEffect (consumeSeeds mailboxes o) `liftedFinally` sealAll)
      )
  where
    -- Type variables are deliberately not named @b@: ScopedTypeVariables would
    -- otherwise bind it to the outer @ContT b@ parameter.
    fst3 :: (p, q, r) -> p
    fst3 (x, _, _) = x
    third :: (p, q, r) -> r
    third (_, _, z) = z

-- | Stream every seed's samples into a single shared output mailbox.
--
-- The seeds are first collected from @seeds@. Each seed then contributes one
-- pipeline, @'streamSamples' dataset seed >-> 'toOutput'' outputBox@. The
-- pipelines are combined with the 'Concurrently' monoid, so they run
-- concurrently, and samples from different seeds reach @outputBox@ in a
-- non-deterministic order.
readSamples ::
  forall m seed dataset sample.
  (Datastream m seed dataset sample, MonadBaseControl IO m) =>
  dataset ->
  ListT m seed ->
  Output sample ->
  m ()
readSamples dataset seeds outputBox =
  join $ P.fold addSeed mempty runConcurrently (enumerate seeds)
  where
    -- Prepend this seed's pipeline to the accumulated concurrent action.
    addSeed :: Concurrently m () -> seed -> Concurrently m ()
    addSeed acc s = Concurrently (pipeSeed s) `mappend` acc

    -- Run one seed's sample stream into the shared output mailbox.
    pipeSeed :: seed -> m ()
    pipeSeed s =
      runEffect $
        enumerate (streamSamples @m @seed @dataset @sample dataset s)
          >-> toOutput' outputBox

-- | Stream each seed's samples into that seed's own output mailbox.
--
-- Every @(seed, output)@ pair contributes one pipeline,
-- @'streamSamples' dataset seed >-> 'toOutput'' output@. All pipelines are
-- combined with the 'Concurrently' monoid (folded left to right over the
-- 'Foldable'), so they run concurrently. Unlike 'readSamples', samples from
-- different seeds never share a mailbox, which is what lets 'streamFrom\''
-- restore a deterministic ordering.
readSamplesDeterministic ::
  forall m seed f dataset sample.
  (Datastream m seed dataset sample, MonadBaseControl IO m, MonadIO m, Foldable f) =>
  dataset ->
  f (seed, Output sample) ->
  m ()
readSamplesDeterministic dataset seeds =
  L.fold (L.Fold addSeed mempty runConcurrently) seeds
  where
    -- Append this seed's pipeline to the accumulated concurrent action.
    addSeed :: Concurrently m () -> (seed, Output sample) -> Concurrently m ()
    addSeed acc (s, outputBox) = acc `mappend` Concurrently (pipeSeed s outputBox)

    -- Run one seed's sample stream into its dedicated output mailbox.
    pipeSeed :: seed -> Output sample -> m ()
    pipeSeed s outputBox =
      runEffect $
        enumerate (streamSamples @m @seed @dataset @sample dataset s)
          >-> toOutput' outputBox

-- | Monadic fold that pairs every seed with a freshly spawned mailbox and
-- collects the pairs, in fold order, into a 'V.Vector'.
--
-- Each mailbox comes from @'spawn'' ('bounded' 1)@, so it holds at most one
-- value at a time. The triple is @(output, input, seal)@, where @seal@ is the
-- action that seals the mailbox.
pairSeedWithBuffer :: forall m seed a. MonadIO m => FoldM m seed (V.Vector (seed, (Output a, Input a, STM ())))
pairSeedWithBuffer = L.premapM (\s -> (s,) <$> makeMailbox) (L.generalize L.vector)
  where
    makeMailbox :: m (Output a, Input a, STM ())
    makeMailbox = liftIO $ spawn' (bounded 1)

-- | Forward at most one value from a mailbox.
--
-- Performs a single 'recv' on @input@:
--
-- * a received value is 'yield'ed downstream;
-- * 'Nothing' (per pipes-concurrency, the mailbox is sealed and drained)
--   increments @workerTracker@ instead. 'streamFrom\'' compares this counter
--   against the number of mailboxes to decide when to stop polling.
fromInputOnce :: MonadIO m => TVar Int -> Input a -> Producer a m ()
fromInputOnce workerTracker input = do
  ma <- atomically $ recv input
  case ma of
    -- Strict modify: a sealed mailbox is hit on every later pass, and the lazy
    -- read/write form would store a growing chain of (+ 1) thunks in the TVar.
    Nothing -> atomically $ modifyTVar' workerTracker (+ 1)
    Just a -> yield a