{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}

module Torch.Internal.Unmanaged.Serialize where

import Foreign.Ptr
import Foreign.C.String
import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Cpp.Unsafe as C
import qualified Language.C.Inline.Context as C
import qualified Language.C.Types as C

import Torch.Internal.Type

-- Configure inline-c for C++ and extend it with 'typeTable' (from
-- "Torch.Internal.Type") so the antiquoted C++ types used in this module
-- (e.g. @std::vector<at::Tensor>*@, @at::IValue*@) can be mapped to Haskell
-- types (presumably 'TensorList' and 'IValue', matching the signatures below).
C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable }

-- C++ headers required by the inline blocks below:
-- <vector> for std::vector, <fstream> for the pickle file streams,
-- <torch/serialize.h> for libtorch's serialization API, and the ATen
-- Tensor / IValue definitions.
C.include "<vector>"
C.include "<fstream>"
C.include "<torch/serialize.h>"
C.include "<ATen/Tensor.h>"
C.include "<ATen/core/ivalue.h>"

-- | Serialize a list of tensors to @file@ with libtorch's @torch::save@.
--
-- The @std::vector@ behind @inputs@ is dereferenced and passed to
-- @torch::save@; it is not freed here, so ownership stays with the caller.
-- The path is marshalled with 'withCString' (current locale encoding).
-- A C++ exception raised by libtorch is rethrown in Haskell by 'C.throwBlock'.
save :: Ptr TensorList -> FilePath -> IO ()
save inputs file = withCString file $ \cfile -> [C.throwBlock| void {
    torch::save(*$(std::vector<at::Tensor>* inputs), $(char* cfile));
  }|]

-- | Load a list of tensors from @file@ with libtorch's @torch::load@
-- (e.g. a file produced by 'save').
--
-- The tensors are read into a local vector, whose contents are then moved into
-- a @std::vector@ allocated with @new@. The returned pointer is owned by the
-- caller (NOTE(review): presumably released by a finalizer or explicit
-- @delete@ elsewhere — confirm).
-- A C++ exception raised by libtorch is rethrown in Haskell by 'C.throwBlock'.
load :: FilePath -> IO (Ptr TensorList)
load file = withCString file $ \cfile -> [C.throwBlock| std::vector<at::Tensor>* {
    std::vector<at::Tensor> tensor_vec;
    torch::load(tensor_vec, $(char* cfile));
    // Move rather than copy: tensor_vec is a local that is discarded anyway.
    return new std::vector<at::Tensor>(std::move(tensor_vec));
  }|]

-- | Serialize an 'IValue' with @torch::pickle_save@ and write the resulting
-- bytes to @file@ in binary mode (the file is created or truncated).
--
-- The value behind @inputs@ is dereferenced, not freed.
--
-- Failing to open the file or to write all bytes raises a
-- @std::runtime_error@ naming the path. That error, like any libtorch
-- exception, is rethrown in Haskell by 'C.throwBlock' instead of silently
-- leaving a missing or truncated file.
pickleSave :: Ptr IValue -> FilePath -> IO ()
pickleSave inputs file = withCString file $ \cfile -> [C.throwBlock| void {
    const char* path = $(char* cfile);
    auto output = torch::pickle_save(*$(at::IValue* inputs));
    std::ofstream fout(path, std::ios::out | std::ios::binary);
    if (!fout) {
      throw std::runtime_error(std::string("pickleSave: cannot open file for writing: ") + path);
    }
    fout.write(output.data(), static_cast<std::streamsize>(output.size()));
    // close() flushes; check afterwards so buffered write errors are caught too.
    fout.close();
    if (!fout) {
      throw std::runtime_error(std::string("pickleSave: failed to write file: ") + path);
    }
  }|]

-- | Read the whole of @file@ in binary mode and decode it with
-- @torch::pickle_load@ (e.g. a file produced by 'pickleSave').
--
-- The decoded value is returned as an @at::IValue@ allocated with @new@. The
-- caller owns the pointer (NOTE(review): presumably freed by a finalizer
-- elsewhere — confirm).
--
-- If the file cannot be opened, a @std::runtime_error@ naming the path is
-- thrown; without this check, an empty buffer would be handed to
-- @pickle_load@. That error, and any libtorch decoding error, is rethrown in
-- Haskell by 'C.throwBlock'.
pickleLoad :: FilePath -> IO (Ptr IValue)
pickleLoad file = withCString file $ \cfile -> [C.throwBlock| at::IValue* {
    const char* path = $(char* cfile);
    std::ifstream fin(path, std::ios::in | std::ios::binary);
    if (!fin) {
      throw std::runtime_error(std::string("pickleLoad: cannot open file for reading: ") + path);
    }
    // Extra parentheses avoid the "most vexing parse".
    const std::vector<char> input((std::istreambuf_iterator<char>(fin)),
                                  std::istreambuf_iterator<char>());
    return new at::IValue(torch::pickle_load(input));
  }|]