module Torch.Serialize where

import Control.Exception.Safe
  ( SomeException (..),
    throwIO,
    try,
  )
import Control.Monad (when)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Ptr as F
import System.IO
import Torch.Autograd
import Torch.DType
import Torch.Functional
import Torch.Internal.Cast
import qualified Torch.Internal.Managed.Serialize as S
import Torch.NN
import Torch.Script hiding (clone, load, save)
import Torch.Tensor

-- | Save a list of tensors to a file using libtorch's serializer
-- ('S.save'). The list can be read back with 'load'.
save ::
  -- | inputs
  [Tensor] ->
  -- | file
  FilePath ->
  -- | output
  IO ()
save = cast2 S.save

-- | Load a list of tensors from a file using libtorch's deserializer
-- ('S.load'). This is the counterpart of 'save'.
load ::
  -- | file
  FilePath ->
  -- | output
  IO [Tensor]
load = cast1 S.load

-- | Save an 'IValue' (typically a state_dict) to a file via libtorch's
-- 'S.pickleSave'. It can be read back with 'pickleLoad'.
pickleSave ::
  -- | inputs
  IValue ->
  -- | file
  FilePath ->
  -- | output
  IO ()
pickleSave = cast2 S.pickleSave

-- | Load a state_dict file into an 'IValue' via libtorch's 'S.pickleLoad'.
--
-- From PyTorch, convert the state_dict to a plain Python @dict@ before
-- saving it, for example:
--
-- > torch.save(dict(model.state_dict()), "state_dict.pth")
pickleLoad ::
  -- | file
  FilePath ->
  -- | output
  IO IValue
pickleLoad = cast1 S.pickleLoad

-- | Save every parameter of @model@ to @filePath@ with 'save'.
--
-- Parameters are written in the order returned by 'flattenParameters',
-- one tensor per parameter. Use 'loadParams' to restore them.
saveParams ::
  Parameterized f =>
  -- | model
  f ->
  -- | filepath
  FilePath ->
  -- | output
  IO ()
saveParams model filePath = save params filePath
  where
    -- Unwrap each independent (trainable) parameter to a plain 'Tensor'.
    params = map toDependent (flattenParameters model)

-- | Load tensors written by 'saveParams' from @filePath@ and install them
-- into @model@ with 'replaceParameters'.
--
-- The file must contain exactly one tensor per entry of
-- @'flattenParameters' model@, which is what 'saveParams' writes. Otherwise
-- an 'IOError' (via 'userError') is thrown. Tensors are presumably matched
-- to parameters positionally, in 'flattenParameters' order — confirm against
-- 'replaceParameters'.
loadParams ::
  Parameterized b =>
  -- | model
  b ->
  -- | filepath
  FilePath ->
  -- | output
  IO b
loadParams model filePath = do
  tensors <- load filePath
  let expected = Prelude.length (flattenParameters model)
      actual = Prelude.length tensors
  -- Fail eagerly with a clear message instead of returning a model whose
  -- parameter list disagrees with the file's contents.
  when (actual /= expected) $
    throwIO $
      userError $
        "loadParams: expected " <> show expected <> " tensors in "
          <> filePath
          <> " but found "
          <> show actual
          <> "."
  pure $ replaceParameters model (map IndependentTensor tensors)

-- | Headerless binary I/O of a value's raw contents through a 'Handle'.
class RawFile t where
  -- | Read a value from the handle. The second argument acts as a template:
  -- for 'Tensor', its dtype and shape determine how many bytes are read.
  loadBinary :: Handle -> t -> IO t

  -- | Write the value's raw contents to the handle.
  saveBinary :: Handle -> t -> IO ()

-- | Raw, headerless tensor I/O: the bytes of the tensor's data buffer are
-- read or written directly. The dtype and shape are not stored, so the
-- reader must supply them through a template tensor.
instance RawFile Tensor where
  -- Read exactly @byteLength (dtype tensor) * product (shape tensor)@ bytes
  -- from @handle@ into a fresh clone of @tensor@ and return that clone. The
  -- template @tensor@ itself is not modified.
  --
  -- Throws an 'IOError' (via 'userError') if the handle yields fewer bytes
  -- than required, e.g. on a premature end of file.
  loadBinary handle tensor = do
    let len = byteLength (dtype tensor) * product (shape tensor)
    v <- BS.hGet handle len
    -- Check the read size before cloning so a short read does not first
    -- allocate a copy of the tensor.
    when (BS.length v < len) $
      throwIO $
        userError $
          "Read data's size is less than input tensor's one(" <> show len <> ")."
    t <- clone tensor
    let (BSI.PS fptr off _) = v
    withTensor t $ \dst ->
      F.withForeignPtr fptr $ \src ->
        -- Honour the ByteString's offset: it is always 0 on bytestring >= 0.11,
        -- but older versions do not guarantee that.
        BSI.memcpy (F.castPtr dst) (src `F.plusPtr` off) len
    return t

  -- Write the tensor's raw data buffer
  -- (@byteLength (dtype tensor) * product (shape tensor)@ bytes) to
  -- @handle@, with no header.
  --
  -- NOTE(review): bytes are taken straight from the tensor's data pointer;
  -- this presumably assumes a contiguous CPU tensor — confirm.
  saveBinary handle tensor = do
    let len = byteLength (dtype tensor) * product (shape tensor)
    withTensor tensor $ \src ->
      hPutBuf handle (F.castPtr src) len