{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Torch.Data.Dataset where

import qualified Control.Foldl as L
import Lens.Family (view)
import Pipes (ListT (Select), Pipe, Producer, enumerate, (>->))
import Pipes.Group (chunksOf, folds)
import Torch.Data.StreamedPipeline

-- | A dataset wrapper that groups the samples of an underlying dataset into
-- chunks and runs a collation pipe over each chunk.
--
-- This type is actually not very useful. It would be better to define a
-- transform on top of another dataset, since the collation could then be
-- done in parallel.
data CollatedDataset m dataset batch collatedBatch = CollatedDataset
  { -- | The underlying dataset whose samples are streamed and grouped.
    set :: dataset,
    -- | Number of samples per group; passed to 'chunksOf' when streaming.
    chunkSize :: Int,
    -- | Pipe that consumes each group (as a list of @batch@ values) and
    -- yields @collatedBatch@ values downstream.
    collateFn :: Pipe [batch] collatedBatch m ()
  }

-- | Stream collated batches from a 'CollatedDataset'.
--
-- The wrapped dataset is streamed with its own 'streamSamples', its samples
-- are grouped into lists using 'chunkSize', and each list is fed through
-- 'collateFn'. The seed is forwarded unchanged to the wrapped dataset.
instance Datastream m seed dataset batch => Datastream m seed (CollatedDataset m dataset batch collatedBatch) collatedBatch where
  -- The pipeline is a function composition, so it reads bottom-to-top.
  streamSamples CollatedDataset {..} =
    Select -- wrap the resulting producer back into a 'ListT'
      . (>-> collateFn) -- run the user-supplied collation pipe on each list
      . L.purely folds L.list -- fold every group into a single list
      . view (chunksOf chunkSize) -- split the stream into groups via 'chunkSize'
      . enumerate -- unwrap the 'ListT' into a 'Producer'
      . streamSamples set -- stream samples from the underlying dataset