{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Data.CsvDatastream
( BufferSize,
NamedColumns (..),
CsvDatastream' (..),
CsvDatastream,
CsvDatastreamNamed,
csvDatastream,
tsvDatastream,
FromField (..),
FromRecord (..),
FromNamedRecord (..),
)
where
import qualified Control.Foldl as L
import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Char (ord)
import Data.Csv (DecodeOptions (decDelimiter))
import Data.STRef
import Data.Vector (Vector)
import qualified Data.Vector as V
import Lens.Family (view)
import Pipes
import qualified Pipes.ByteString as B
import Pipes.Csv
import Pipes.Group (chunksOf, folds)
import qualified Pipes.Prelude as P
import qualified Pipes.Safe as Safe
import qualified Pipes.Safe.Prelude as Safe
import System.IO (IOMode (ReadMode))
import System.Random
import Torch.Data.StreamedPipeline
-- | Type-level tag (promoted through @DataKinds@) recording how the columns
-- of a CSV file are matched to record fields.
data NamedColumns
  = -- | Columns are identified by their position.
    Unnamed
  | -- | Columns are identified by their header name.
    Named
-- | Number of records held in one shuffle buffer (see the @bufferedShuffle@
-- field of @CsvDatastream'@).
type BufferSize = Int
-- | Configuration for streaming batches of decoded records from a
-- delimiter-separated file.
--
-- @batches@ is the record type each row is decoded into. The phantom index
-- @named@ selects the 'Datastream' instance: 'Unnamed' decodes rows by column
-- position, 'Named' decodes them by header name.
data CsvDatastream' batches (named :: NamedColumns) = CsvDatastream'
  { -- | Path of the file; it is opened in 'ReadMode' when streaming starts.
    filePath :: FilePath,
    -- | Column delimiter byte, passed to the decoder as @decDelimiter@.
    delimiter :: !B.Word8,
    -- | Whether the file starts with a header row. Only the positional
    -- ('Unnamed') instance passes this to the decoder.
    hasHeader :: HasHeader,
    -- | Number of records collected into each yielded batch.
    batchSize :: Int,
    -- | When @Just n@, decoded records are grouped into buffers of @n@ and
    -- shuffled within a buffer before batching; 'Nothing' keeps file order.
    bufferedShuffle :: Maybe BufferSize,
    -- | When 'True', batches whose length differs from 'batchSize' are
    -- filtered out (in practice the trailing, short batch).
    dropLast :: Bool
  }
-- | A @CsvDatastream'@ specialised to 'Unnamed' columns; its 'Datastream'
-- instance requires a 'FromRecord' instance for @batches@.
type CsvDatastream batches = CsvDatastream' batches Unnamed
-- | A @CsvDatastream'@ specialised to 'Named' columns; its 'Datastream'
-- instance requires a 'FromNamedRecord' instance for @batches@.
type CsvDatastreamNamed batches = CsvDatastream' batches Named
-- | Build a datastream configuration for the given file using the defaults of
-- 'csvDatastream', but with a tab character as the column delimiter.
tsvDatastream :: forall (isNamed :: NamedColumns) batches. FilePath -> CsvDatastream' batches isNamed
tsvDatastream path = (csvDatastream path) {delimiter = fromIntegral $ ord '\t'}
-- | Build a datastream configuration for the given file with default options:
-- comma-separated columns, no header row, a batch size of 1, no shuffling,
-- and 'dropLast' enabled.
csvDatastream :: forall (isNamed :: NamedColumns) batches. FilePath -> CsvDatastream' batches isNamed
csvDatastream path =
  CsvDatastream'
    { filePath = path,
      delimiter = fromIntegral $ ord ',',
      hasHeader = NoHeader,
      batchSize = 1,
      bufferedShuffle = Nothing,
      dropLast = True
    }
-- | Stream batches from a file whose rows are decoded by column position
-- ('FromRecord'). The configured 'delimiter' and 'hasHeader' are forwarded to
-- 'decodeWith'; the seed argument is ignored.
instance
  ( MonadBaseControl IO m,
    Safe.MonadSafe m,
    FromRecord batch
  ) =>
  Datastream m () (CsvDatastream batch) (Vector batch)
  where
  streamSamples csv@CsvDatastream' {..} _ =
    readCsv csv (decodeWith (defaultDecodeOptions {decDelimiter = delimiter}) hasHeader)
-- | Stream batches from a file whose rows are decoded by header name
-- ('FromNamedRecord'). Only the configured 'delimiter' is forwarded to
-- 'decodeByNameWith'; the 'hasHeader' field and the seed argument are not
-- used by this instance.
instance
  ( MonadBaseControl IO m,
    Safe.MonadSafe m,
    FromNamedRecord batch
  ) =>
  Datastream m () (CsvDatastreamNamed batch) (Vector batch)
  where
  streamSamples csv@CsvDatastream' {..} _ =
    readCsv csv (decodeByNameWith (defaultDecodeOptions {decDelimiter = delimiter}))
-- | Open the configured file and stream its decoded records as batches.
--
-- The second argument turns a producer of raw 'ByteString' chunks into a
-- producer of decoded rows wrapped in some 'Foldable' container. Rows are
-- flattened with 'P.concat'; at the call sites in this module the container
-- is @Either String@, so rows that fail to decode are silently dropped.
--
-- No explicit signature is given: the type is left to inference, which
-- requires (at least) 'Foldable' on the row container and 'Safe.MonadSafe',
-- 'MonadIO' and @MonadBase IO@ on the monad.
--
-- Records are grouped into 'Vector's of 'batchSize'. With
-- @bufferedShuffle = Just n@ the records are first grouped into lists of @n@,
-- each list is shuffled with a fresh generator, and the shuffled records are
-- then re-batched. With 'dropLast' set, batches whose length is not exactly
-- 'batchSize' are discarded.
readCsv CsvDatastream' {..} decode =
  Select $
    Safe.withFile filePath ReadMode $ \fh ->
      if dropLast
        then streamRecords fh >-> P.filter (\v -> V.length v == batchSize)
        else streamRecords fh
  where
    streamRecords fh = case bufferedShuffle of
      Nothing ->
        L.purely folds L.vector $
          view (chunksOf batchSize) $
            decode (produceLine fh) >-> P.concat
      Just bufferSize ->
        L.purely folds L.vector $
          view (chunksOf batchSize) $
            ( L.purely folds L.list $
                view (chunksOf bufferSize) $
                  decode (produceLine fh) >-> P.concat
            )
              >-> shuffleRecords
    -- Read the handle in chunks of at most 1000 bytes.
    produceLine fh = B.hGetSome 1000 fh
    -- Shuffle every incoming buffer and re-yield its records one by one.
    -- The loop is essential: a pipe that awaits only once returns after the
    -- first buffer, and '>->' then terminates the whole stream, truncating
    -- the file to its first @bufferSize@ records.
    shuffleRecords = forever $ do
      chunks <- await
      std <- Torch.Data.StreamedPipeline.liftBase newStdGen
      mapM_ yield $ fst $ shuffle' chunks std
-- | Return a random permutation of the list together with the advanced
-- generator.
--
-- The list is copied into a mutable 1-based array inside 'ST'. For each
-- position @i@ a partner @j@ is drawn uniformly from @[i, n]@; the element at
-- @j@ is emitted and the element at @i@ is moved into slot @j@, so every
-- element is emitted exactly once. An empty list yields an empty list and
-- leaves the generator unchanged.
shuffle' :: [a] -> StdGen -> ([a], StdGen)
shuffle' xs gen =
  runST
    ( do
        g <- newSTRef gen
        -- Draw from the generator stored in @g@ and write back the new state.
        let randomRST lohi = do
              (a, s') <- randomR lohi <$> readSTRef g
              writeSTRef g s'
              return a
        ar <- listToSTArray n xs
        xs' <- forM [1 .. n] $ \i -> do
          j <- randomRST (i, n)
          vi <- readArray ar i
          vj <- readArray ar j
          writeArray ar j vi
          return vj
        gen' <- readSTRef g
        return (xs', gen')
    )
  where
    n = Prelude.length xs
    -- The local signature pins the array type to 'STArray'; the name avoids
    -- shadowing 'newArray' from "Data.Array.ST".
    listToSTArray :: Int -> [a] -> ST s (STArray s Int a)
    listToSTArray len ys = newListArray (1, len) ys