{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}

module Torch.Data.Pipeline
  ( -- * Defining a Dataset
    -- $dataset

    -- * Dataset
    Dataset (..),
    DatasetOptions (..),
    datasetOpts,
    Sample (..),

    -- * Dataloading
    streamFromMap,
  )
where

import Control.Concurrent.Async.Lifted
import Control.Concurrent.STM hiding (atomically)
import Control.Monad
import Control.Monad.Base (MonadBase)
import Control.Monad.Cont (ContT)
import Control.Monad.Trans.Control (MonadBaseControl (..))
import Data.IntMap (IntMap)
import qualified Data.IntMap as I
import qualified Data.List as List
import Data.Set
import Pipes
import Pipes.Concurrent hiding (atomically)
import System.Random
import Torch.Data.Internal

-- $dataset
-- See the "Torch.Vision" module, which implements the MNIST dataset, for a good example of how to define a dataset.

-- | The base class for datasets: anything that can produce a @sample@ for a
-- key of type @k@, and that knows up front which keys it contains.
--
-- The dataset type determines the monad @m@ samples are retrieved in, the
-- @sample@ type and the key type @k@ (see the functional dependencies).
class
  (Ord k) =>
  Dataset m dataset k sample
    | dataset -> m,
      dataset -> sample,
      dataset -> k
  where
  -- | Retrieve the sample stored under the given key.
  getItem :: dataset -> k -> m sample

  -- | The set of keys of the dataset.
  keys :: dataset -> Set k

-- | Options controlling how a dataset is loaded: the order in which samples
-- are streamed, how many worker threads retrieve them, and how many retrieved
-- samples may be buffered.
data DatasetOptions = DatasetOptions
  { dataBufferSize :: Int
  -- ^ Max number of samples stored in each buffer at a given time.
  , numWorkers :: Int
  -- ^ Number of threads retrieving samples.
  , shuffle :: Sample
  -- ^ The ordering of samples streamed.
  }

-- | Default 'DatasetOptions' for the given number of workers.
--
-- The buffer size is set equal to the number of workers, and samples are
-- streamed in 'Sequential' order.
datasetOpts :: Int -> DatasetOptions
datasetOpts workerCount =
  DatasetOptions
    { shuffle = Sequential,
      numWorkers = workerCount,
      dataBufferSize = workerCount
    }

-- | A 'Sample' determines the ordering of samples streamed out of a dataset.
-- You can either order sequentially, or supply a random generator to shuffle samples.
data Sample where
  -- | Stream samples in the order in which the dataset's keys are listed.
  Sequential :: Sample
  -- | Stream samples in a shuffled order derived from the supplied random generator.
  Shuffle :: RandomGen g => g -> Sample

---------------------- Workflow --------------------
-- - pair every key with a TVar that will hold its sample (the pairs are possibly shuffled);
--   the TVars track which keys have been sampled
-- - create a TQueue of these (key, TVar) pairs (using the pipes-concurrency wrapper)
-- - fork off workers which all pull from the TQueue and sample that key using the dataset,
--   then update the TVar associated with that key
-- - have a worker waiting for each successive key to be updated in the list of (key, TVar)

-- | Stream every sample of a dataset, together with an updated 'Sample' value.
--
-- The returned stream yields the result of @'getItem'@ for every key in
-- @'keys' dataset@, in the order selected by the 'shuffle' option. The
-- returned 'Sample' is 'Sequential' when sampling is 'Sequential'; when
-- sampling is 'Shuffle' it wraps the random generator left over after
-- shuffling, so it can be used for a subsequent call.
streamFromMap ::
  forall m dataset k sample r.
  (Dataset m dataset k sample, MonadIO m, MonadBaseControl IO m) =>
  DatasetOptions ->
  dataset ->
  ContT r m (ListT m sample, Sample)
streamFromMap DatasetOptions {..} dataset = do
  -- Mailbox through which the workers receive (key, TVar) pairs.
  (keyOutput, keyInput, seal) <- liftIO (spawn' unbounded)

  -- One empty TVar per key; the list order is the order samples are streamed in.
  slots <- liftIO (keyTVarSet (keys dataset))
  let (orderedSlots, updatedSample) = case shuffle of
        Sequential -> (slots, Sequential)
        Shuffle g ->
          let (shuffled, g') = fisherYates g slots
           in (shuffled, Shuffle g')

  -- Fill the queue with each key and its associated TVar, then seal it
  -- (sealing presumably lets the workers finish once the queue is drained).
  keyQueue keyOutput orderedSlots
  liftIO (atomically seal)

  -- Workers fill the TVars concurrently while a single consumer forwards the
  -- samples to the buffer in list order.
  listT <- runWithBuffer dataBufferSize $ \output ->
    concurrently_
      (runWorkers numWorkers dataset keyInput)
      (awaitNextItem orderedSlots output)
  pure (listT, updatedSample)

-- | Run the pool of worker threads that retrieve samples.
--
-- Each worker repeatedly takes a @(key, TVar)@ pair from the given input,
-- retrieves the sample for that key with 'getItem', and stores it in the
-- pair's 'TVar' as @'Just' sample@.
--
-- At least one worker is always started, even when the requested number is
-- zero or negative: with no workers nothing would ever fill the 'TVar's, and
-- a consumer waiting on them would never make progress.
runWorkers ::
  (Dataset m dataset k sample, MonadIO m, MonadBaseControl IO m) =>
  Int ->
  dataset ->
  Input (k, TVar (Maybe sample)) ->
  m ()
runWorkers numWorkers dataset keyInput =
  replicateConcurrently_ workerCount (runEffect (fromInput' keyInput >-> worker))
  where
    -- Never start fewer than one worker.
    workerCount = max 1 numWorkers
    -- Fetch the sample for each incoming key and publish it through the key's TVar.
    worker = forever $ do
      (key, slot) <- await
      item <- lift (getItem dataset key)
      atomically (writeTVar slot (Just item))

-- | Forward samples to the given output in the order of the given list.
--
-- For each @(key, TVar)@ pair in turn, block until the 'TVar' holds a sample,
-- take the sample out, and send it to the output.
awaitNextItem ::
  (MonadBase IO m, MonadIO m) =>
  [(k, TVar (Maybe sample))] ->
  Output sample ->
  m ()
awaitNextItem tvars output =
  runEffect (each tvars >-> forever forwardOne >-> toOutput' output)
  where
    -- Take the sample of the next pair and pass it downstream.
    forwardOne = do
      (_, slot) <- await
      yield =<< atomically (takeSample slot)
    -- Wait until the slot is filled, then reset it to 'Nothing' once the
    -- sample has been taken out, to save memory.
    takeSample slot =
      readTVar slot >>= \contents -> case contents of
        Nothing -> retry
        Just item -> item <$ writeTVar slot Nothing

-- | Pair every key of the set with a freshly created 'TVar' holding 'Nothing'.
--
-- The pairs are returned in the order in which @toList@ lists the key set,
-- and all 'TVar's are created within a single STM transaction.
keyTVarSet :: MonadIO m => Set k -> m [(k, TVar (Maybe sample))]
keyTVarSet keySet = atomically (forM (toList keySet) withEmptySlot)
  where
    -- Attach a new, still empty, slot to a key.
    withEmptySlot key = (key,) <$> newTVar Nothing

-- | Send every @(key, TVar)@ pair of the list, in list order, to the given output.
keyQueue ::
  MonadBase IO m =>
  Output (k, TVar (Maybe sample)) ->
  [(k, TVar (Maybe sample))] ->
  m ()
keyQueue keyOutput pairs = runEffect (each pairs >-> toOutput' keyOutput)

-- | One step of the Fisher-Yates shuffle.
--
-- Takes the elements placed so far (an 'IntMap' expected to hold the indices
-- @0 .. i - 1@) with the current generator, and a new element @x@ with its
-- index @i@. A random index @j@ is drawn from @[0, i]@; the element currently
-- at @j@ (if any) moves to index @i@ and @x@ is stored at index @j@.
-- Returns the extended map together with the advanced generator.
fisherYatesStep :: RandomGen g => (IntMap a, g) -> (Int, a) -> (IntMap a, g)
fisherYatesStep (m, gen) (i, x) = (placed, gen')
  where
    (j, gen') = randomR (0, i) gen
    -- A total lookup is used instead of 'I.!': when @j == i@ the index is not
    -- in the map yet, and indexing it would only be safe as long as the
    -- resulting error value happened never to be evaluated.
    placed = case I.lookup j m of
      -- An occupied index was drawn: its element moves to the new index.
      Just displaced -> I.insert j x (I.insert i displaced m)
      -- Nothing stored at @j@ yet (the @j == i@ draw): @x@ just takes that index.
      Nothing -> I.insert j x m

-- | Shuffle a list with the Fisher-Yates algorithm.
--
-- Returns a random permutation of the input list together with the generator
-- left over after shuffling. An empty list is returned as is, with the
-- generator untouched.
fisherYates :: RandomGen g => g -> [a] -> ([a], g)
fisherYates gen [] = ([], gen)
fisherYates gen (x : xs) =
  toElems (List.foldl' strictStep (I.singleton 0 x, gen) (zip [1 ..] xs))
  where
    -- Read the shuffled elements back out of the index map.
    toElems (shuffled, gen') = (I.elems shuffled, gen')
    -- Force the map after every step so the fold does not accumulate a chain
    -- of unevaluated steps on long lists (the elements are not forced).
    strictStep acc indexed = case fisherYatesStep acc indexed of
      acc'@(m, _) -> m `seq` acc'