module Torch.Serialize where
import Control.Exception.Safe
( SomeException (..),
throwIO,
try,
)
import Control.Monad (when)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Ptr as F
import System.IO
import Torch.Autograd
import Torch.DType
import Torch.Functional
import Torch.Internal.Cast
import qualified Torch.Internal.Managed.Serialize as S
import Torch.NN
import Torch.Script hiding (clone, load, save)
import Torch.Tensor
-- | Serialize a list of tensors to a file.
--
-- The list is marshalled through 'cast2' and handed to the 'S.save' binding
-- from "Torch.Internal.Managed.Serialize"; 'load' is the matching reader.
save ::
  -- | tensors to write
  [Tensor] ->
  -- | destination path
  FilePath ->
  IO ()
save tensors path = cast2 S.save tensors path
-- | Read back a list of tensors from a file.
--
-- The path is marshalled through 'cast1', passed to the 'S.load' binding
-- from "Torch.Internal.Managed.Serialize", and the resulting native tensor
-- list is converted back into Haskell 'Tensor's.
load ::
  -- | source path
  FilePath ->
  IO [Tensor]
load path = cast1 S.load path
-- | Write an 'IValue' to a file using the 'S.pickleSave' binding.
--
-- Both arguments are marshalled through 'cast2'; 'pickleLoad' is the
-- matching reader.
pickleSave ::
  -- | value to write
  IValue ->
  -- | destination path
  FilePath ->
  IO ()
pickleSave value path = cast2 S.pickleSave value path
-- | Read an 'IValue' from a file using the 'S.pickleLoad' binding.
--
-- The path is marshalled through 'cast1' and the native result is converted
-- back into an 'IValue'.
pickleLoad ::
  -- | source path
  FilePath ->
  IO IValue
pickleLoad path = cast1 S.pickleLoad path
-- | Save every parameter of a model to a file.
--
-- Parameters are taken in 'flattenParameters' order, unwrapped with
-- 'toDependent', and written with 'save'. 'loadParams' relies on that same
-- ordering when reading them back.
saveParams ::
  Parameterized f =>
  -- | model whose parameters are written
  f ->
  -- | destination path
  FilePath ->
  IO ()
saveParams model filePath = save params filePath
  where
    params = toDependent <$> flattenParameters model
-- | Load parameters from a file into a copy of the given model.
--
-- The tensors read by 'load' are wrapped directly with the
-- 'IndependentTensor' constructor and substituted into the model with
-- 'replaceParameters'. They must therefore appear in 'flattenParameters'
-- order, as produced by 'saveParams'.
loadParams ::
  Parameterized b =>
  -- | model supplying the structure to fill in
  b ->
  -- | source path
  FilePath ->
  IO b
loadParams model filePath =
  replaceParameters model . map IndependentTensor <$> load filePath
-- | Values that can be written to, and read back from, a 'Handle' as raw
-- bytes.
class RawFile a where
  -- | Write the raw contents of a value to the handle.
  saveBinary :: Handle -> a -> IO ()

  -- | Read raw contents from the handle. The second argument acts as a
  -- template; for example, the 'Tensor' instance returns a filled-in clone
  -- of it.
  loadBinary :: Handle -> a -> IO a
instance RawFile Tensor where
  -- Fill a fresh clone of the template tensor with raw bytes read from the
  -- handle. Exactly @byteLength (dtype tensor) * product (shape tensor)@
  -- bytes are consumed, and a short read raises a 'userError'. The template
  -- tensor itself is never written to.
  -- NOTE(review): bytes are copied verbatim into the clone's data buffer, so
  -- this presumably assumes a host-accessible, contiguous layout -- confirm.
  loadBinary handle tensor = do
    let len :: Int
        len = byteLength (dtype tensor) * product (shape tensor)
    v <- BS.hGet handle len
    -- Reject a short read (e.g. EOF) before paying for the clone.
    when (BS.length v < len) $
      throwIO $
        userError $
          "Read data's size is less than input tensor's one(" <> show len <> ")."
    t <- clone tensor
    withTensor t $ \dst -> do
      -- A ByteString's payload starts @off@ bytes past the base of its
      -- foreign pointer, so the offset must be applied to the source.
      let (BSI.PS fptr off len') = v
      F.withForeignPtr fptr $ \src ->
        BSI.memcpy (F.castPtr dst) (src `F.plusPtr` off) (Prelude.min len len')
    return t

  -- Write the tensor's raw bytes to the handle, starting at its data pointer.
  -- There are @byteLength (dtype tensor) * product (shape tensor)@ of them.
  saveBinary handle tensor = do
    let len :: Int
        len = byteLength (dtype tensor) * product (shape tensor)
    -- The buffer is only read here, so no defensive clone is needed.
    withTensor tensor $ \src -> hPutBuf handle src len