{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}

module Torch.Data.Utils
  ( pmap,
    pmap',
    pmapGroup,
    bufferedCollate,
    collate,
    enumerateData,
    CachedDataset,
    cache,
  )
where

import qualified Control.Foldl as L
import Control.Monad.Cont
import Control.Monad.Trans.Control
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as M
import qualified Data.Set as S
import Lens.Family
import Pipes
import Pipes.Concurrent
import Pipes.Group
import qualified Pipes.Prelude as P
import Torch.Data.Internal
import Torch.Data.Pipeline

-- | Run a map function in parallel over the given stream.
--
-- The elements of @prod@ are mapped with @f@ and pushed into the 'Output' end
-- of a mailbox created from @buffer@; the continuation receives a 'ListT'
-- that reads the mapped values back from the matching 'Input' end.
--
-- The result of the producing side is discarded ('snd'); only the result of
-- the continuation is returned.
--
-- NOTE(review): 'withBufferLifted' lives in "Torch.Data.Internal"; presumably
-- it runs both callbacks concurrently over the shared mailbox -- confirm there.
pmap :: (MonadIO m, MonadBaseControl IO m) => Buffer b -> (a -> b) -> ListT m a -> ContT r m (ListT m b)
pmap buffer f prod = ContT $ \cont ->
  snd
    <$> withBufferLifted
      buffer
      -- Producer side: drain the input stream through @f@ into the mailbox.
      (\output -> runEffect $ enumerate prod >-> P.map f >-> toOutput' output)
      -- Consumer side: expose the mailbox contents to the continuation.
      (\input -> cont $ Select $ fromInput' input)

-- | Run a pipe in parallel over the given stream.
--
-- Like 'pmap', but the transformation is an arbitrary 'Pipe' instead of a
-- pure function: the elements of @prod@ are fed through the pipe @f@ and
-- whatever it yields is pushed into the 'Output' end of a mailbox created
-- from @buffer@. The continuation receives a 'ListT' reading from the
-- matching 'Input' end.
--
-- The result of the producing side is discarded ('snd'); only the result of
-- the continuation is returned.
--
-- NOTE(review): 'withBufferLifted' lives in "Torch.Data.Internal"; presumably
-- it runs both callbacks concurrently over the shared mailbox -- confirm there.
pmap' :: (MonadIO m, MonadBaseControl IO m) => Buffer b -> Pipe a b m () -> ListT m a -> ContT r m (ListT m b)
pmap' buffer f prod = ContT $ \cont ->
  snd
    <$> withBufferLifted
      buffer
      -- Producer side: run the input stream through the pipe into the mailbox.
      (\output -> runEffect $ enumerate prod >-> f >-> toOutput' output)
      -- Consumer side: expose the mailbox contents to the continuation.
      (\input -> cont $ Select $ fromInput' input)

-- | Map a 'ListT' transform over the given stream in parallel. This is useful
-- for functions which group elements of a stream and yield the groups
-- downstream (see 'collate').
--
-- The transform @f@ is applied to the whole input stream @prod@, and the
-- elements of the resulting stream are pushed into the 'Output' end of a
-- mailbox created from @buffer@. The continuation receives a 'ListT' reading
-- from the matching 'Input' end.
--
-- The result of the producing side is discarded ('snd'); only the result of
-- the continuation is returned.
--
-- NOTE(review): 'withBufferLifted' lives in "Torch.Data.Internal"; presumably
-- it runs both callbacks concurrently over the shared mailbox -- confirm there.
pmapGroup :: (MonadIO m, MonadBaseControl IO m) => Buffer b -> (ListT m a -> ListT m b) -> ListT m a -> ContT r m (ListT m b)
pmapGroup buffer f prod = ContT $ \cont ->
  snd
    <$> withBufferLifted
      buffer
      -- Producer side: transform the whole stream, then drain it into the mailbox.
      (\output -> runEffect $ enumerate (f prod) >-> toOutput' output)
      -- Consumer side: expose the mailbox contents to the continuation.
      (\input -> cont $ Select $ fromInput' input)

-- | Enumerate the given stream, zipping each element with an index.
--
-- Indices start at @0@ and are paired as the /second/ component, i.e. the
-- producer yields @(element, index)@.
enumerateData :: Monad m => ListT m a -> Producer (a, Int) m ()
enumerateData input = P.zip (enumerate input) (each [0 ..])

-- | Run a given batching function in parallel. See 'collate' for how the
-- given samples are batched.
--
-- This is 'collate' run through 'pmapGroup': batches produced by
-- @collate batchSize collateFn@ are passed through a mailbox created from
-- @buffer@ before reaching the continuation.
bufferedCollate :: (MonadIO m, MonadBaseControl IO m) => Buffer batch -> Int -> ([sample] -> Maybe batch) -> ListT m sample -> ContT r m (ListT m batch)
bufferedCollate buffer batchSize collateFn =
  pmapGroup buffer (collate batchSize collateFn)

-- | Run a batching function with integer batch size over the given stream.
--
-- The elements of the stream are split into lists of the given batch size and
-- each list is collated with the given function. Only 'Just' results are
-- yielded downstream; chunks for which the function returns 'Nothing' are
-- dropped. If the last chunk of samples is smaller than the given batch size,
-- the batching function is passed a list shorter than the batch size.
collate :: Monad m => Int -> ([sample] -> Maybe batch) -> ListT m sample -> ListT m batch
collate batchSize collateFn =
  Select
    -- Apply the collate function; 'P.mapFoldable' over 'Maybe' skips 'Nothing'.
    . (>-> P.mapFoldable collateFn)
    -- Fold every chunk into a plain list of samples.
    . L.purely folds L.list
    -- Split the stream into consecutive chunks of @batchSize@ elements.
    . view (chunksOf batchSize)
    . enumerate

-- | An in-memory cached dataset, backed by an 'IntMap' from index to sample.
-- See the 'cache' function for how to create a cached dataset.
--
-- The @m@ parameter is phantom: it does not occur in the stored data.
newtype CachedDataset (m :: * -> *) sample = CachedDataset {cached :: IntMap sample}

-- | Enumerate a given stream and store it as a 'CachedDataset'. This function should
-- be used after a time consuming preprocessing pipeline, so that subsequent epochs
-- can read from the cache instead of repeating the preprocessing pipeline.
--
-- Samples are keyed by their position in the stream, starting at @0@, so a
-- stream of @n@ samples yields the keys @0 .. n - 1@.
cache :: Monad m => ListT m sample -> m (CachedDataset m sample)
cache datastream = P.fold step begin done . enumerate $ datastream
  where
    -- Insert the sample under the next free index and advance the index.
    -- A pair in weak head normal form can still hold unevaluated components,
    -- so both are forced here to keep thunks from piling up over long streams.
    step :: (IntMap a, Int) -> a -> (IntMap a, Int)
    step (cacheMap, ix) sample =
      let cacheMap' = M.insert ix sample cacheMap
          ix' = ix + 1
       in cacheMap' `seq` ix' `seq` (cacheMap', ix')

    -- Empty cache; the first sample gets index 0.
    begin :: (IntMap a, Int)
    begin = (M.empty, 0)

    -- Drop the running index and wrap the finished map.
    done :: (IntMap a, Int) -> CachedDataset n a
    done = CachedDataset . fst

-- | A 'CachedDataset' is a 'Dataset' keyed by 'Int', served purely from memory.
instance Applicative m => Dataset m (CachedDataset m sample) Int sample where
  -- Look the sample up in the cache. 'M.!' is partial: a key that is not in
  -- the cache raises an error once the returned sample is demanded, so callers
  -- should only pass keys obtained from 'keys'.
  getItem CachedDataset {..} key = pure $ cached M.! key

  -- Report exactly the keys present in the cache. The previous enumeration
  -- @[0 .. M.size cached]@ included the index @size@, which 'cache' never
  -- inserts, so 'getItem' on that key would fail. 'M.keys' returns the keys
  -- in ascending order, as 'S.fromDistinctAscList' requires.
  keys CachedDataset {..} = S.fromDistinctAscList (M.keys cached)