{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Torch.Data.Dataset where
import qualified Control.Foldl as L
import Lens.Family (view)
import Pipes (ListT (Select), Pipe, Producer, enumerate, (>->))
import Pipes.Group (chunksOf, folds)
import Torch.Data.StreamedPipeline
-- | A dataset bundled with the information needed to stream it in collated
-- batches: how many samples go into one chunk, and the pipe that converts a
-- list of samples into a collated batch.
--
-- The type parameters are:
--
-- * @m@: the monad the collation pipe runs in
-- * @dataset@: the wrapped dataset
-- * @batch@: the sample type handed (as a list) to the collation pipe
-- * @collatedBatch@: the type emitted by the collation pipe
data CollatedDataset m dataset batch collatedBatch = CollatedDataset
  { -- | The wrapped dataset whose samples are grouped into chunks.
    set :: dataset,
    -- | Number of consecutive samples gathered into each list passed to
    -- 'collateFn'.
    -- NOTE(review): no validation is performed here; behaviour for
    -- non-positive values depends on 'Pipes.Group.chunksOf' -- confirm.
    chunkSize :: Int,
    -- | Pipe that consumes lists of samples and yields collated batches.
    collateFn :: Pipe [batch] collatedBatch m ()
  }
-- | Any 'CollatedDataset' wrapping a streamable dataset is itself streamable:
-- the wrapped dataset's samples are grouped into chunks of 'chunkSize',
-- each chunk is collected into a list, and the lists are piped through
-- 'collateFn' to produce the collated batches.
instance Datastream m seed dataset batch => Datastream m seed (CollatedDataset m dataset batch collatedBatch) collatedBatch where
  -- Read the composition bottom-up: it is a function of the seed.
  streamSamples CollatedDataset {..} =
    -- Wrap the resulting producer back into a 'ListT'.
    Select
      -- Feed each list of samples through the user-supplied collation pipe.
      . (>-> collateFn)
      -- Fold every chunk into a list, yielding one list per chunk.
      . L.purely folds L.list
      -- Split the flat sample stream into groups of 'chunkSize' samples.
      . view (chunksOf chunkSize)
      -- Unwrap the 'ListT' of samples into a plain producer.
      . enumerate
      -- Stream the individual samples of the wrapped dataset.
      . streamSamples set