{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

module Torch.Typed.Serialize where

import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Serialize as S
import qualified Torch.Internal.Type as ATen
import qualified Torch.Tensor as D
import Torch.Typed.Tensor

-- | Save a heterogeneous list of typed tensors to a file.
--
-- Both arguments are marshalled with 'ATen.cast2' before 'S.save' is
-- called. The 'HList' is converted through its 'ATen.Castable' instance
-- into a list of 'D.ATenTensor's, and the 'FilePath' is converted through
-- its own 'ATen.Castable' instance.
save ::
  forall tensors.
  ATen.Castable (HList tensors) [D.ATenTensor] =>
  -- | list of input tensors
  HList tensors ->
  -- | file
  FilePath ->
  IO ()
save = ATen.cast2 S.save

-- | Load a heterogeneous list of typed tensors from a file.
--
-- The 'FilePath' is marshalled with 'ATen.cast1' and passed to 'S.load'.
-- The result is then cast back into an 'HList' through the
-- 'ATen.Castable' instance for @HList tensors@. The element types of the
-- result come from the caller's choice of @tensors@. This function does
-- not compare them against the file contents itself.
-- NOTE(review): confirm how the 'ATen.Castable' instance behaves when the
-- number or types of stored tensors do not match @tensors@.
load ::
  forall tensors.
  ATen.Castable (HList tensors) [D.ATenTensor] =>
  -- | file
  FilePath ->
  IO (HList tensors)
load = ATen.cast1 S.load