{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Data.CsvDatastream
  ( BufferSize,
    NamedColumns (..),
    CsvDatastream' (..),
    CsvDatastream,
    CsvDatastreamNamed,
    csvDatastream,
    tsvDatastream,

    -- * Reexports
    FromField (..),
    FromRecord (..),
    FromNamedRecord (..),
  )
where

import qualified Control.Foldl as L
import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Char (ord)
import Data.Csv (DecodeOptions (decDelimiter))
import Data.STRef
import Data.Vector (Vector)
import qualified Data.Vector as V
import Lens.Family (view)
import Pipes
import qualified Pipes.ByteString as B
import Pipes.Csv
import Pipes.Group (chunksOf, folds)
import qualified Pipes.Prelude as P
import qualified Pipes.Safe as Safe
import qualified Pipes.Safe.Prelude as Safe
import System.IO (IOMode (ReadMode))
import System.Random
import Torch.Data.StreamedPipeline

-- | How the columns of a CSV file are matched to record fields. Promoted
-- (via DataKinds) and used as the @named@ index of 'CsvDatastream''.
data NamedColumns
  = -- | Columns are decoded by position ('FromRecord').
    Unnamed
  | -- | Columns are decoded by header name ('FromNamedRecord').
    Named

-- | Maximum number of records held in memory by the buffered shuffle.
type BufferSize = Int

-- TODO: implement more options

-- | A CSV datastream. The 'Datastream' instances for this type stream
-- @'Vector'@s of decoded records, each holding at most 'batchSize' records,
-- from the CSV file at 'filePath'. Batches are yielded in constant memory, but
-- if shuffling is enabled then up to @'BufferSize'@ records are held in
-- memory at once.
--
-- Despite its name, @batches@ is the type of a single decoded record; a batch
-- is a @'Vector' batches@. The @named@ index picks the decoding scheme:
-- 'Unnamed' decodes columns by position ('FromRecord'), 'Named' by header
-- name ('FromNamedRecord').
data CsvDatastream' batches (named :: NamedColumns) = CsvDatastream'
  { -- | CSV file path.
    filePath :: FilePath,
    -- | Column delimiter, as a single byte.
    delimiter :: !B.Word8,
    -- | Does the file have a header row? Only consulted when decoding by
    -- position; by-name decoding always reads column names from the header.
    hasHeader :: HasHeader,
    -- TODO: record filtering, e.g. @filter :: Maybe (batches -> Bool)@.
    -- | Number of records per batch.
    batchSize :: Int,
    -- | Buffered shuffle with the specified buffer size ('Nothing' disables
    -- shuffling).
    bufferedShuffle :: Maybe BufferSize,
    -- | Drop the last batch if it is smaller than 'batchSize'.
    dropLast :: Bool
  }

-- | 'CsvDatastream'' specialised to positional decoding: each row is decoded
-- with the record type's 'FromRecord' instance, by column order.
type CsvDatastream batches = CsvDatastream' batches 'Unnamed

-- | 'CsvDatastream'' specialised to by-name decoding: each row is decoded with
-- the record type's @'FromNamedRecord'@ instance, matching fields to columns
-- by their header name.
type CsvDatastreamNamed batches = CsvDatastream' batches 'Named

-- | Describe a tab-separated file at the given path. All options match
-- 'csvDatastream' except the column delimiter, which is the tab character.
tsvDatastream :: forall (isNamed :: NamedColumns) batches. FilePath -> CsvDatastream' batches isNamed
tsvDatastream path = defaults {delimiter = tab}
  where
    defaults = csvDatastream path
    tab = fromIntegral (ord '\t')

-- | Describe a comma-separated file at the given path, using the default
-- options: no header row, one record per batch, no shuffling, and short
-- trailing batches dropped.
csvDatastream :: forall (isNamed :: NamedColumns) batches. FilePath -> CsvDatastream' batches isNamed
csvDatastream path =
  CsvDatastream'
    { filePath = path,
      delimiter = comma,
      hasHeader = NoHeader,
      batchSize = 1,
      -- , filter = Nothing
      bufferedShuffle = Nothing,
      dropLast = True
    }
  where
    comma = fromIntegral (ord ',')

-- | Stream batches of records decoded by column position.
instance
  ( MonadBaseControl IO m,
    Safe.MonadSafe m,
    FromRecord batch
  ) =>
  Datastream m () (CsvDatastream batch) (Vector batch)
  where
  -- Rows are decoded with 'decodeWith' using the configured delimiter and
  -- header setting; 'readCsv' takes care of batching and shuffling.
  streamSamples csv _ = readCsv csv decoder
    where
      decoder = decodeWith options (hasHeader csv)
      options = defaultDecodeOptions {decDelimiter = delimiter csv}

-- | Stream batches of records decoded by header name.
instance
  ( MonadBaseControl IO m,
    Safe.MonadSafe m,
    FromNamedRecord batch
  ) =>
  Datastream m () (CsvDatastreamNamed batch) (Vector batch)
  where
  -- Rows are decoded with 'decodeByNameWith' using the configured delimiter.
  -- The 'hasHeader' field is not passed on: by-name decoding takes its column
  -- names from the file's header row.
  streamSamples csv _ = readCsv csv decoder
    where
      decoder = decodeByNameWith options
      options = defaultDecodeOptions {decDelimiter = delimiter csv}

-- | Stream the records of a CSV file as batches ('Vector's of at most
-- 'batchSize' records).
--
-- The file at 'filePath' is opened with 'Safe.withFile', read in chunks of up
-- to 1000 bytes and fed through @decoder@. Every decoded value is flattened
-- with 'P.concat', so for an 'Either'-producing decoder, rows that fail to
-- decode are silently dropped.
--
-- * 'bufferedShuffle' = @'Just' bufferSize@: records are gathered into
--   buffers of @bufferSize@, each buffer is shuffled with a fresh generator,
--   and the shuffled records are re-batched.
-- * 'dropLast' = 'True': batches shorter than 'batchSize' are discarded.
readCsv ::
  (Foldable f, Safe.MonadSafe m) =>
  CsvDatastream' batches named ->
  (Producer B.ByteString m () -> Producer (f a) m ()) ->
  ListT m (Vector a)
readCsv CsvDatastream' {..} decoder =
  Select $
    Safe.withFile filePath ReadMode $ \fh ->
      if dropLast
        then streamRecords fh >-> P.filter (\v -> V.length v == batchSize)
        else streamRecords fh
  where
    -- Group a stream of records into vectors of 'batchSize' records.
    batched = L.purely folds L.vector . view (chunksOf batchSize)

    -- Decoded records of the file. This quietly discards decoding errors right
    -- now; logging them would probably be better.
    -- TODO: what's a good default chunk size?
    records fh = decoder (B.hGetSome 1000 fh) >-> P.concat

    streamRecords fh = case bufferedShuffle of
      Nothing -> batched (records fh)
      Just bufferSize ->
        batched $
          L.purely folds L.list (view (chunksOf bufferSize) (records fh))
            >-> shuffleRecords

    -- Shuffle every buffer received from upstream and re-yield its records.
    -- The loop is essential: a pipe that awaits only once terminates the
    -- whole pipeline, which would drop everything after the first buffer.
    shuffleRecords = forever $ do
      buffer <- await
      gen <- liftIO newStdGen
      mapM_ yield (fst (shuffle' buffer gen))

-- | Randomly permute a list, returning the permutation together with the
-- generator state after all draws.
--
-- This is the Fisher–Yates shuffle from
-- <https://wiki.haskell.org/Random_shuffle>. For each position @i@ in
-- @[1 .. n]@, one index @j@ is drawn from @[i, n]@. The element at @j@ becomes
-- the @i@-th output, and the element at @i@ moves into slot @j@. Exactly @n@
-- draws are taken from the generator; the empty list yields @([], gen)@.
shuffle' :: [a] -> StdGen -> ([a], StdGen)
shuffle' xs gen =
  runST $ do
    genRef <- newSTRef gen
    -- Draw from the inclusive range lohi, advancing the shared generator.
    let randomRST lohi = do
          (k, gen'') <- randomR lohi <$> readSTRef genRef
          writeSTRef genRef gen''
          return k
    arr <- newSTArray xs
    xs' <- forM [1 .. n] $ \i -> do
      j <- randomRST (i, n)
      vi <- readArray arr i
      vj <- readArray arr j
      -- Slots before i are never read again, so only slot j is updated.
      writeArray arr j vi
      return vj
    gen' <- readSTRef genRef
    return (xs', gen')
  where
    n = Prelude.length xs
    -- 1-indexed mutable copy of the input. The signature pins the concrete
    -- array type, which 'newListArray' alone would leave ambiguous.
    newSTArray :: [e] -> ST s (STArray s Int e)
    newSTArray = newListArray (1, n)