{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Typed access to the MNIST data set: loading the gzipped IDX files and
-- serving fixed-size batches of typed image/label tensors through the
-- 'Dataset' interface.
module Torch.Typed.Vision where

import qualified Codec.Compression.GZip as GZip
import Control.Monad (forM_)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import qualified Data.ByteString.Lazy as BS.Lazy
import Data.Kind
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Ptr as F
import GHC.Exts (IsList (fromList))
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import Torch.Data.Pipeline
import qualified Torch.Device as D
import Torch.Internal.Cast
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Tensor as D
import qualified Torch.TensorOptions as D
import Torch.Typed.Auxiliary
import Torch.Typed.Functional
import Torch.Typed.Tensor

-- | MNIST wrapped as a 'Dataset' that yields batches of @batchSize@
-- consecutive samples, moved to @device@.
data MNIST (m :: Type -> Type) (device :: (D.DeviceType, Nat)) (batchSize :: Nat) = MNIST
  { -- | The raw (decompressed) image and label bytes.
    mnistData :: MnistData
  }

-- | Key @ix@ selects samples @[ix * batchSize .. (ix + 1) * batchSize - 1]@.
-- Images are returned as a @[batchSize, 784]@ float tensor (raw pixel
-- values, not normalised) and labels as a @[batchSize]@ int64 tensor.
instance
  (KnownNat batchSize, KnownDevice device, Applicative m) =>
  Dataset
    m
    (MNIST m device batchSize)
    Int
    (Tensor device 'D.Float '[batchSize, 784], Tensor device 'D.Int64 '[batchSize])
  where
  getItem MNIST {..} ix =
    let batchSize = natValI @batchSize
        indexes = [ix * batchSize .. (ix + 1) * batchSize - 1]
        imgs = getImages @batchSize mnistData indexes
        labels = getLabels @batchSize mnistData indexes
     in pure (toDevice @device imgs, toDevice @device labels)

  -- Only full batches are enumerated; a trailing partial batch is dropped
  -- by the integer division.
  keys MNIST {..} =
    fromList [0 .. (Torch.Typed.Vision.length mnistData `Prelude.div` natValI @batchSize) - 1]

-- | Raw decompressed contents of one MNIST image file and its label file.
data MnistData = MnistData
  { -- | Image file bytes; pixel data starts after a 16-byte header.
    images :: BS.ByteString,
    -- | Label file bytes; one byte per label after an 8-byte header.
    labels :: BS.ByteString
  }

type Rows = 28

type Cols = 28

-- | Number of pixels in one flattened image (784).
type DataDim = Rows * Cols

type ClassDim = 10

-- | Labels for the first @n@ of the given sample indices, as an int64 tensor.
--
-- NOTE(review): if fewer than @n@ indices are supplied, the tensor holds
-- fewer than @n@ elements even though its type says @[n]@.
getLabels :: forall n. KnownNat n => MnistData -> [Int] -> CPUTensor 'D.Int64 '[n]
getLabels mnist imageIdxs =
  UnsafeMkTensor . D.asTensor . map (getLabel mnist) . take (natValI @n) $ imageIdxs

-- | Label of sample @imageIdx@. Skips the 8-byte label-file header;
-- 'BS.index' throws if the index is out of range.
getLabel :: MnistData -> Int -> Int
getLabel mnist imageIdx =
  fromIntegral $ BS.index (labels mnist) (imageIdx + labelHeaderLen)
  where
    labelHeaderLen :: Int
    labelHeaderLen = 8

-- | Flattened pixels of sample @imageIdx@ as a @[784]@ float tensor.
-- Skips the 16-byte image-file header; 'BS.index' throws if out of range.
getImage :: MnistData -> Int -> CPUTensor 'D.Float '[DataDim]
getImage mnist imageIdx = UnsafeMkTensor $ D.asTensor pixels
  where
    dataDim = natValI @DataDim
    pixels :: [Float]
    pixels =
      [ fromIntegral $ BS.index (images mnist) (imageIdx * dataDim + 16 + r)
        | r <- [0 .. dataDim - 1]
      ]

-- | Pure (list-based) variant of 'getImages': pixels of the first @n@
-- indices as an @[n, 784]@ float tensor. Bounds-checked via 'BS.index'.
--
-- NOTE(review): if fewer than @n@ indices are supplied, the tensor has
-- fewer than @n@ rows even though its type says @[n, DataDim]@.
getImages' :: forall n. KnownNat n => MnistData -> [Int] -> CPUTensor 'D.Float '[n, DataDim]
getImages' mnist imageIdxs =
  UnsafeMkTensor $ D.asTensor $ map image $ take (natValI @n) imageIdxs
  where
    dataDim = natValI @DataDim
    image :: Int -> [Float]
    image idx =
      [ fromIntegral $ BS.index (images mnist) (idx * dataDim + 16 + r)
        | r <- [0 .. dataDim - 1]
      ]

-- | Pixels of the first @n@ indices as an @[n, 784]@ float tensor, copied
-- row by row with 'BSI.memcpy' into a uint8 tensor that is then cast to float.
--
-- Every index is validated against the image buffer before any raw copy,
-- because 'BSI.memcpy' performs no bounds checking. The function calls
-- 'error' when:
--
-- * fewer than @n@ indices are supplied (otherwise rows of the freshly
--   allocated tensor would be left uninitialised), or
-- * an index is negative or would read past the end of the image data.
getImages ::
  forall n.
  KnownNat n =>
  MnistData ->
  [Int] ->
  CPUTensor 'D.Float '[n, DataDim]
getImages mnist imageIdxs
  | Prelude.length selected < batch =
      error $
        "Torch.Typed.Vision.getImages: expected at least "
          <> show batch
          <> " image indices, got "
          <> show (Prelude.length selected)
  | (badIdx : _) <- outOfRange =
      error $ "Torch.Typed.Vision.getImages: image index out of range: " <> show badIdx
  | otherwise =
      -- unsafePerformIO is acceptable here: the result depends only on the
      -- (immutable) arguments, and the tensor is freshly allocated and fully
      -- written before it escapes.
      UnsafeMkTensor $ unsafePerformIO $ do
        let (BSI.PS fptr off _) = imageBytes
        t <-
          (cast2 LibTorch.empty_lo :: [Int] -> D.TensorOptions -> IO D.Tensor)
            [batch, dataDim]
            (D.withDType D.UInt8 D.defaultOpts)
        D.withTensor t $ \dst ->
          F.withForeignPtr fptr $ \src ->
            forM_ (zip [0 ..] selected) $ \(row, idx) ->
              BSI.memcpy
                (F.plusPtr dst (dataDim * row))
                (F.plusPtr src (off + imageHeaderLen + dataDim * idx))
                dataDim
        return $ D.toType D.Float t
  where
    batch :: Int
    batch = natValI @n
    dataDim :: Int
    dataDim = natValI @DataDim
    imageHeaderLen :: Int
    imageHeaderLen = 16
    imageBytes :: BS.ByteString
    imageBytes = images mnist
    selected :: [Int]
    selected = take batch imageIdxs
    -- Indices whose pixel block would not lie entirely inside imageBytes.
    outOfRange :: [Int]
    outOfRange =
      [ idx
        | idx <- selected,
          idx < 0 || imageHeaderLen + dataDim * (idx + 1) > BS.length imageBytes
      ]

-- | Number of samples, derived from the label file (its size minus the
-- 8-byte header).
length :: MnistData -> Int
length mnist = fromIntegral $ BS.length (labels mnist) - 8

-- | Read @path/file@ and gunzip it into a strict 'BS.ByteString'.
decompressFile :: String -> String -> IO BS.ByteString
decompressFile path file = decompress' <$> BS.readFile (path <> "/" <> file)
  where
    decompress' :: BS.ByteString -> BS.ByteString
    decompress' = BS.Lazy.toStrict . GZip.decompress . BS.Lazy.fromStrict

-- | Load the training and test splits from the gzipped MNIST files in @path@.
-- Returns @(train, test)@.
initMnist :: String -> IO (MnistData, MnistData)
initMnist path = do
  imagesBS <- decompressFile path "train-images-idx3-ubyte.gz"
  labelsBS <- decompressFile path "train-labels-idx1-ubyte.gz"
  testImagesBS <- decompressFile path "t10k-images-idx3-ubyte.gz"
  testLabelsBS <- decompressFile path "t10k-labels-idx1-ubyte.gz"
  return (MnistData imagesBS labelsBS, MnistData testImagesBS testLabelsBS)