{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Vision where

import qualified Codec.Compression.GZip as GZip
import Control.Monad (forM_)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import qualified Data.ByteString.Lazy as BS.Lazy
import Data.Kind
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Ptr as F
import GHC.Exts (IsList (fromList))
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import Torch.Data.Pipeline
import qualified Torch.Device as D
import Torch.Internal.Cast
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Tensor as D
import qualified Torch.TensorOptions as D
import Torch.Typed.Auxiliary
import Torch.Typed.Functional
import Torch.Typed.Tensor

-- | Batched view of an 'MnistData' set, usable as a 'Dataset'.
--
-- Only 'mnistData' exists at runtime; the type parameters select the monad
-- 'getItem' runs in (@m@), the device each batch is moved to (@device@) and
-- the number of samples per batch (@batchSize@).
data MNIST (m :: Type -> Type) (device :: (D.DeviceType, Nat)) (batchSize :: Nat) = MNIST
  { mnistData :: MnistData
  }

instance
  (KnownNat batchSize, KnownDevice device, Applicative m) =>
  Dataset m (MNIST m device batchSize) Int (Tensor device 'D.Float '[batchSize, 784], Tensor device 'D.Int64 '[batchSize])
  where
  -- Batch number @ix@ is the contiguous run of @batchSize@ samples starting
  -- at sample @ix * batchSize@; images and labels are built on the CPU and
  -- then moved to @device@.
  getItem MNIST {..} ix =
    pure (toDevice @device batchImages, toDevice @device batchLabels)
    where
      n = natValI @batchSize
      batchIdxs = take n [ix * n ..]
      batchImages = getImages @batchSize mnistData batchIdxs
      batchLabels = getLabels @batchSize mnistData batchIdxs

  -- One key per complete batch; a trailing partial batch is not exposed.
  keys MNIST {..} = fromList [0 .. numBatches - 1]
    where
      numBatches =
        Torch.Typed.Vision.length mnistData `Prelude.div` natValI @batchSize

-- | Decompressed contents of one MNIST split, as produced by 'initMnist'
-- from an @*-images-idx3-ubyte.gz@ / @*-labels-idx1-ubyte.gz@ pair.
data MnistData = MnistData
  { -- | Whole images file, header included; pixel bytes are read starting
    -- at offset 16.
    images :: BS.ByteString,
    -- | Whole labels file, header included; one label byte per sample is
    -- read starting at offset 8.
    labels :: BS.ByteString
  }

-- | Number of pixel rows in one MNIST image.
type Rows = 28

-- | Number of pixel columns in one MNIST image.
type Cols = 28

-- | Length of one flattened image (@Rows * Cols@, i.e. 784 pixels).
type DataDim = Rows * Cols

-- | Number of label classes (presumably the digits 0-9).
type ClassDim = 10

-- | Labels for the first @n@ entries of @imageIdxs@, as an @[n]@ int64
-- tensor.
--
-- Calls 'error' when fewer than @n@ indices are supplied: the resulting
-- tensor would otherwise be shorter than the @'[n]@ shape its type promises.
-- Extra indices beyond the first @n@ are ignored.
getLabels ::
  forall n. KnownNat n => MnistData -> [Int] -> CPUTensor 'D.Int64 '[n]
getLabels mnist imageIdxs
  | Prelude.length batchIdxs < batch =
      error "Torch.Typed.Vision.getLabels: fewer label indices than the batch size"
  | otherwise =
      UnsafeMkTensor . D.asTensor . map (getLabel mnist) $ batchIdxs
  where
    batch = natValI @n
    batchIdxs = take batch imageIdxs

-- | Class label of sample @imageIdx@: the byte found @imageIdx@ positions
-- past the 8-byte header of the labels buffer.
--
-- Partial: 'BS.index' raises an error for an out-of-range index.
getLabel :: MnistData -> Int -> Int
getLabel mnist imageIdx =
  fromIntegral (labels mnist `BS.index` (imageIdx + labelHeaderBytes))
  where
    labelHeaderBytes = 8 :: Int

-- | Decode the image at position @imageIdx@ into a flat @[DataDim]@ float
-- tensor. Pixels are the raw byte values converted to 'Float' (no rescaling).
--
-- Partial: 'BS.index' raises an error if the image lies past the end of the
-- buffer.
getImage :: MnistData -> Int -> CPUTensor 'D.Float '[DataDim]
getImage mnist imageIdx = UnsafeMkTensor (D.asTensor pixels)
  where
    pixelsPerImage = 28 * 28 :: Int
    -- skip the 16-byte header, then jump to the requested image
    firstByte = 16 + imageIdx * pixelsPerImage
    pixels :: [Float]
    pixels =
      map
        (fromIntegral . BS.index (images mnist))
        [firstByte .. firstByte + pixelsPerImage - 1]

-- | List-based variant of 'getImages'. It decodes the images at the first
-- @n@ entries of @imageIdxs@ into an @[n, DataDim]@ float tensor by building
-- one @[Float]@ per image.
--
-- Calls 'error' when fewer than @n@ indices are supplied, since the tensor
-- would otherwise have fewer rows than its @'[n, DataDim]@ type promises.
-- Out-of-range indices raise the error from 'BS.index'.
getImages' ::
  forall n.
  KnownNat n =>
  MnistData ->
  [Int] ->
  CPUTensor 'D.Float '[n, DataDim]
getImages' mnist imageIdxs
  | Prelude.length batchIdxs < batch =
      error "Torch.Typed.Vision.getImages': fewer image indices than the batch size"
  | otherwise = UnsafeMkTensor . D.asTensor $ map image batchIdxs
  where
    batch = natValI @n
    batchIdxs = take batch imageIdxs
    pixelsPerImage = 28 * 28 :: Int
    -- pixels of one image; the buffer starts with a 16-byte header
    image :: Int -> [Float]
    image idx =
      [ fromIntegral $ BS.index (images mnist) (idx * pixelsPerImage + 16 + r)
        | r <- [0 .. pixelsPerImage - 1]
      ]

-- | Gather the images at the first @n@ entries of @imageIdxs@ into an
-- @[n, DataDim]@ float tensor.
--
-- Unlike 'getImages'', which builds Haskell lists, this copies bytes straight
-- from the 'images' buffer into a freshly allocated uint8 tensor and then
-- converts it to float.
--
-- Because the copy is a raw 'BSI.memcpy', indices are validated first. The
-- function calls 'error' in two cases:
--
-- * fewer than @n@ indices are given, which would leave tensor rows
--   uninitialised;
-- * an index is negative or addresses bytes past the end of the buffer,
--   which would be an out-of-bounds read.
getImages ::
  forall n.
  KnownNat n =>
  MnistData ->
  [Int] ->
  CPUTensor 'D.Float '[n, DataDim]
getImages mnist imageIdxs
  | Prelude.length batchIdxs < batch =
      error "Torch.Typed.Vision.getImages: fewer image indices than the batch size"
  | Prelude.any outOfRange batchIdxs =
      error "Torch.Typed.Vision.getImages: image index out of range"
  | otherwise =
      -- unsafePerformIO: the result depends only on the arguments. A fresh
      -- tensor is allocated and filled from an immutable ByteString, so no
      -- effect should be observable outside this expression.
      UnsafeMkTensor $
        unsafePerformIO $ do
          t <-
            (cast2 LibTorch.empty_lo :: [Int] -> D.TensorOptions -> IO D.Tensor)
              [batch, dataDim]
              (D.withDType D.UInt8 D.defaultOpts)
          D.withTensor t $ \dst ->
            F.withForeignPtr fptr $ \src ->
              forM_ (zip [0 ..] batchIdxs) $ \(row, idx) ->
                BSI.memcpy
                  (F.plusPtr dst (dataDim * row))
                  (F.plusPtr src (off + headerBytes + dataDim * idx))
                  dataDim
          return $ D.toType D.Float t
  where
    batch = natValI @n
    dataDim = natValI @DataDim
    -- bytes preceding the pixel data in the images buffer
    headerBytes = 16 :: Int
    batchIdxs = take batch imageIdxs
    BSI.PS fptr off len = images mnist
    -- complete images actually present after the header
    imagesAvailable = (len - headerBytes) `Prelude.div` dataDim
    outOfRange idx = idx < 0 || idx >= imagesAvailable

-- | Number of samples in the set: the size of the labels buffer minus its
-- 8-byte header (one label byte per sample).
length :: MnistData -> Int
length mnist = BS.length (labels mnist) - labelHeaderBytes
  where
    labelHeaderBytes = 8 :: Int

-- | Read the gzip-compressed file @path ++ "/" ++ file@ and return its
-- decompressed contents as a strict 'BS.ByteString'.
--
-- The decompressed value is returned unevaluated, so a corrupt archive
-- presumably only raises an error when the result is first forced.
decompressFile :: String -> String -> IO BS.ByteString
decompressFile path file = do
  compressed <- BS.readFile (path <> "/" <> file)
  pure (gunzip compressed)
  where
    gunzip =
      BS.concat . BS.Lazy.toChunks . GZip.decompress . BS.Lazy.fromStrict

-- | Load the MNIST training and test splits from the gzip files in @path@.
--
-- Returns @(train, test)@. The four files are read in the order train
-- images, train labels, test images, test labels.
initMnist :: String -> IO (MnistData, MnistData)
initMnist path = do
  train <- loadSplit "train-images-idx3-ubyte.gz" "train-labels-idx1-ubyte.gz"
  test <- loadSplit "t10k-images-idx3-ubyte.gz" "t10k-labels-idx1-ubyte.gz"
  return (train, test)
  where
    loadSplit imagesFile labelsFile =
      MnistData
        <$> decompressFile path imagesFile
        <*> decompressFile path labelsFile