{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}

module Torch.Internal.Unmanaged.Type.Module where

import Control.Exception.Safe (bracket)
import Data.IORef
import qualified Data.Map as Map
import Foreign
import Foreign.C.String
import Foreign.C.Types
import qualified Language.C.Inline.Context as C
import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Cpp.Unsafe as C
import qualified Language.C.Inline.Cpp.Exceptions as Safe
import qualified Language.C.Types as C
import Torch.Internal.Type
import Torch.Internal.Unmanaged.Helper
import Control.Exception.Safe (bracket)
import Control.Monad (forM)

-- inline-C++ context: the stock C++ context extended with the project's
-- 'typeTable', so the C++ types named in the quasi-quotes below
-- (torch::jit::script::Module*, at::Tensor*, ...) resolve to Haskell types.
C.context (C.cppCtx <> mempty { C.ctxTypesTable = typeTable })

-- Headers for the quasi-quoted C++ in this module (order preserved):
--   torch/script.h          - torch::jit::script::Module, torch::jit::load
--   serialization/export.h  - torch::jit::pretty_print_onnx
--   frontend/tracer.h       - torch::jit::tracer::trace
C.include "<torch/script.h>"
C.include "<torch/csrc/jit/serialization/export.h>"
C.include "<torch/csrc/jit/frontend/tracer.h>"
C.include "<vector>"
C.include "<iostream>"

-- Bindings for torch::jit::script::Module (declared in libtorch/include/torch/csrc/jit/api/module.h).

-- | Construct a new TorchScript module named by the given string.
--
-- The module is allocated with C++ @new@; this function does not free it.
newModule :: Ptr StdString -> IO (Ptr Module)
newModule moduleName =
  [C.throwBlock| torch::jit::script::Module* {
    const std::string& n = *$(std::string* moduleName);
    return new torch::jit::script::Module(n);
  }|]

-- | Serialize the module to the file at the given path using
-- @Module::save@. C++ exceptions surface through 'C.throwBlock'.
save :: Ptr Module -> FilePath -> IO ()
save obj file = withCString file saveTo
  where
    -- Runs with the path marshalled to a NUL-terminated C string.
    saveTo cfile =
      [C.throwBlock| void {
        $(torch::jit::script::Module* obj)->save($(char* cfile));
      }|]

-- | Load a serialized TorchScript module from the given path with
-- @torch::jit::load@. The result is heap-allocated with @new@ and is not
-- freed here.
load :: FilePath -> IO (Ptr Module)
load file =
  withCString file $ \path ->
    [C.throwBlock| torch::jit::script::Module* {
      return new torch::jit::script::Module(
        torch::jit::load($(char* path))
      );
    }|]

-- | Invoke the module's @forward@ method on the given argument vector.
--
-- The argument vector is passed by copy. The returned 'IValue' is a fresh
-- heap copy (@new@) that this function does not free.
forward :: Ptr Module -> Ptr (StdVector IValue) -> IO (Ptr IValue)
forward obj inputs =
  [C.throwBlock| at::IValue* {
    auto* self = $(torch::jit::script::Module* obj);
    auto& args = *$(std::vector<at::IValue>* inputs);
    return new at::IValue(self->forward(args));
  }|]

-- | Register a tensor on the module under the given name.
--
-- The final flag is forwarded as @register_parameter@'s @is_buffer@
-- argument.
registerParameter :: Ptr Module -> Ptr StdString -> Ptr Tensor -> CBool -> IO ()
registerParameter obj name v isBuffer =
  [C.throwBlock| void {
    auto* self = $(torch::jit::script::Module* obj);
    const std::string& key = *$(std::string* name);
    self->register_parameter(key, *$(at::Tensor* v), $(bool isBuffer));
  }|]

-- | Attach the second module to the first as a submodule under the given
-- name, via @register_module@.
registerModule :: Ptr Module -> Ptr StdString -> Ptr Module -> IO ()
registerModule obj name v =
  [C.throwBlock| void {
    auto* parent = $(torch::jit::script::Module* obj);
    const std::string& key = *$(std::string* name);
    parent->register_module(key, *$(torch::jit::script::Module* v));
  }|]

-- | Switch the module's training mode on or off (@Module::train@).
train :: Ptr Module -> CBool -> IO ()
train obj on =
  [C.throwBlock| void {
    bool mode = $(bool on);
    $(torch::jit::script::Module* obj)->train(mode);
  }|]

-- | Call the named method of the module, passing the given
-- @c10::List<IValue>@ as its single argument.
--
-- The result is a fresh heap copy (@new@) that is not freed here.
runMethod :: Ptr Module -> Ptr StdString -> Ptr (C10List IValue) -> IO (Ptr IValue)
runMethod obj method_name args =
  [C.throwBlock| at::IValue* {
    auto* self = $(torch::jit::script::Module* obj);
    const std::string& method = *$(std::string* method_name);
    auto& arg_list = *$(c10::List<at::IValue>* args);
    return new at::IValue(self->run_method(method, arg_list));
  }|]

-- | Call the named method of the module with exactly one 'IValue'
-- argument.
--
-- The result is a fresh heap copy (@new@) that is not freed here.
runMethod1 :: Ptr Module -> Ptr StdString -> Ptr IValue -> IO (Ptr IValue)
runMethod1 obj method_name args =
  [C.throwBlock| at::IValue* {
    auto* self = $(torch::jit::script::Module* obj);
    const std::string& method = *$(std::string* method_name);
    auto& arg = *$(at::IValue* args);
    return new at::IValue(self->run_method(method, arg));
  }|]

-- | Collect the module's parameters (@Module::parameters()@) into a newly
-- allocated @std::vector<at::Tensor>@.
--
-- The vector is allocated with @new@ and is not freed here.
getParameters :: Ptr Module -> IO (Ptr TensorList)
getParameters obj =
  [C.throwBlock| std::vector<at::Tensor>* {
    auto* out = new std::vector<at::Tensor>();
    for (const auto& param : $(torch::jit::script::Module* obj)->parameters()) {
      out->push_back(param);
    }
    return out;
  }|]

-- | Overwrite the module's parameters with the tensors in the given list.
--
-- Parameters are visited in the order reported by @named_parameters()@.
-- The i-th parameter is re-registered under its existing name with the
-- i-th tensor of the list, as a parameter rather than a buffer.
--
-- If the list holds fewer tensors than the module has parameters,
-- @std::vector::at@ throws @std::out_of_range@, which surfaces through
-- 'C.throwBlock'. Extra tensors at the end of the list are ignored.
setParameters :: Ptr Module -> Ptr TensorList -> IO ()
setParameters obj params =
  [C.throwBlock| void {
    auto module = $(torch::jit::script::Module* obj);
    auto parameters = module->named_parameters();
    auto vec = $(std::vector<at::Tensor>* params);
    size_t i = 0;
    for (auto p : parameters) {
      // Walk the tensor list in lockstep with the parameters. The index
      // must advance on every iteration; .at() bounds-checks it.
      // NOTE(review): named_parameters() may report dotted names for
      // submodule parameters; registering those on the top-level module
      // may not update the submodule. Confirm for nested modules.
      module->register_parameter(p.name, vec->at(i++), false);
    }
  }|]

-- | List the module's parameters as (name, tensor) pairs.
--
-- The pairs from @named_parameters()@ are first copied into a temporary
-- C++ vector, which 'bracket' deletes once every entry has been read out.
-- Each returned string and tensor is a fresh heap copy (@new@) that this
-- function does not free.
getNamedParameters :: Ptr Module -> IO [(Ptr StdString, Ptr Tensor)]
getNamedParameters _obj =
  bracket acquire release $ \dat -> do
    count <-
      [C.throwBlock| int64_t {
        return (long int)$(std::vector<std::tuple<std::string,at::Tensor>>* dat)->size();
      }|]
    traverse (entryAt dat) [0 .. count - 1]
  where
    -- Snapshot named_parameters() into a heap-allocated vector.
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, Tensor))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,at::Tensor>>* {
        auto module = $(torch::jit::script::Module* _obj);
        auto entries = module->named_parameters();
        auto snapshot = new std::vector<std::tuple<std::string,at::Tensor>>();
        for (auto entry : entries) {
          snapshot->push_back({entry.name, entry.value});
        }
        return snapshot;
      }|]

    -- Delete the snapshot vector.
    release :: Ptr (StdVector (StdTuple '(StdString, Tensor))) -> IO ()
    release dat =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,at::Tensor>>* dat);
      }|]

    -- Copy out the name and tensor stored at index i.
    entryAt :: Ptr (StdVector (StdTuple '(StdString, Tensor))) -> Int64 -> IO (Ptr StdString, Ptr Tensor)
    entryAt dat i = do
      key <- [C.throwBlock| std::string* {
               return new std::string(std::get<0>($(std::vector<std::tuple<std::string,at::Tensor>>* dat)->at($(int64_t i))));
             }|]
      val <- [C.throwBlock| at::Tensor* {
               return new at::Tensor(std::get<1>($(std::vector<std::tuple<std::string,at::Tensor>>* dat)->at($(int64_t i))));
             }|]
      pure (key, val)

-- | List the module's buffers as (name, tensor) pairs.
--
-- The pairs from @named_buffers()@ are first copied into a temporary C++
-- vector, which 'bracket' deletes once every entry has been read out. Each
-- returned string and tensor is a fresh heap copy (@new@) that this
-- function does not free.
getNamedBuffers :: Ptr Module -> IO [(Ptr StdString, Ptr Tensor)]
getNamedBuffers _obj =
  bracket acquire release $ \dat -> do
    count <-
      [C.throwBlock| int64_t {
        return (long int)$(std::vector<std::tuple<std::string,at::Tensor>>* dat)->size();
      }|]
    traverse (entryAt dat) [0 .. count - 1]
  where
    -- Snapshot named_buffers() into a heap-allocated vector.
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, Tensor))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,at::Tensor>>* {
        auto module = $(torch::jit::script::Module* _obj);
        auto entries = module->named_buffers();
        auto snapshot = new std::vector<std::tuple<std::string,at::Tensor>>();
        for (auto entry : entries) {
          snapshot->push_back({entry.name, entry.value});
        }
        return snapshot;
      }|]

    -- Delete the snapshot vector.
    release :: Ptr (StdVector (StdTuple '(StdString, Tensor))) -> IO ()
    release dat =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,at::Tensor>>* dat);
      }|]

    -- Copy out the name and tensor stored at index i.
    entryAt :: Ptr (StdVector (StdTuple '(StdString, Tensor))) -> Int64 -> IO (Ptr StdString, Ptr Tensor)
    entryAt dat i = do
      key <- [C.throwBlock| std::string* {
               return new std::string(std::get<0>($(std::vector<std::tuple<std::string,at::Tensor>>* dat)->at($(int64_t i))));
             }|]
      val <- [C.throwBlock| at::Tensor* {
               return new at::Tensor(std::get<1>($(std::vector<std::tuple<std::string,at::Tensor>>* dat)->at($(int64_t i))));
             }|]
      pure (key, val)

-- | List the module's attributes as (name, value) pairs.
--
-- The pairs from @named_attributes()@ are first copied into a temporary
-- C++ vector, which 'bracket' deletes once every entry has been read out.
-- Each returned string and 'IValue' is a fresh heap copy (@new@) that this
-- function does not free.
getNamedAttributes :: Ptr Module -> IO [(Ptr StdString, Ptr IValue)]
getNamedAttributes _obj =
  bracket acquire release $ \dat -> do
    count <-
      [C.throwBlock| int64_t {
        return (long int)$(std::vector<std::tuple<std::string,at::IValue>>* dat)->size();
      }|]
    traverse (entryAt dat) [0 .. count - 1]
  where
    -- Snapshot named_attributes() into a heap-allocated vector.
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, IValue))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,at::IValue>>* {
        auto module = $(torch::jit::script::Module* _obj);
        auto entries = module->named_attributes();
        auto snapshot = new std::vector<std::tuple<std::string,at::IValue>>();
        for (auto entry : entries) {
          snapshot->push_back({entry.name, entry.value});
        }
        return snapshot;
      }|]

    -- Delete the snapshot vector.
    release :: Ptr (StdVector (StdTuple '(StdString, IValue))) -> IO ()
    release dat =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,at::IValue>>* dat);
      }|]

    -- Copy out the name and value stored at index i.
    entryAt :: Ptr (StdVector (StdTuple '(StdString, IValue))) -> Int64 -> IO (Ptr StdString, Ptr IValue)
    entryAt dat i = do
      key <- [C.throwBlock| std::string* {
               return new std::string(std::get<0>($(std::vector<std::tuple<std::string,at::IValue>>* dat)->at($(int64_t i))));
             }|]
      val <- [C.throwBlock| at::IValue* {
               return new at::IValue(std::get<1>($(std::vector<std::tuple<std::string,at::IValue>>* dat)->at($(int64_t i))));
             }|]
      pure (key, val)

-- | List the (name, module) pairs reported by @named_modules()@.
--
-- The pairs are first copied into a temporary C++ vector, which 'bracket'
-- deletes once every entry has been read out. Each returned string and
-- module is a fresh heap copy (@new@) that this function does not free.
getNamedModules :: Ptr Module -> IO [(Ptr StdString, Ptr Module)]
getNamedModules _obj =
  bracket acquire release $ \dat -> do
    count <-
      [C.throwBlock| int64_t {
        return (long int)$(std::vector<std::tuple<std::string,torch::jit::script::Module>>* dat)->size();
      }|]
    traverse (entryAt dat) [0 .. count - 1]
  where
    -- Snapshot named_modules() into a heap-allocated vector.
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, Module))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,torch::jit::script::Module>>* {
        auto module = $(torch::jit::script::Module* _obj);
        auto entries = module->named_modules();
        auto snapshot = new std::vector<std::tuple<std::string,torch::jit::script::Module>>();
        for (auto entry : entries) {
          snapshot->push_back({entry.name, entry.value});
        }
        return snapshot;
      }|]

    -- Delete the snapshot vector.
    release :: Ptr (StdVector (StdTuple '(StdString, Module))) -> IO ()
    release dat =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,torch::jit::script::Module>>* dat);
      }|]

    -- Copy out the name and module stored at index i.
    entryAt :: Ptr (StdVector (StdTuple '(StdString, Module))) -> Int64 -> IO (Ptr StdString, Ptr Module)
    entryAt dat i = do
      key <- [C.throwBlock| std::string* {
               return new std::string(std::get<0>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* dat)->at($(int64_t i))));
             }|]
      val <- [C.throwBlock| torch::jit::script::Module* {
               return new torch::jit::script::Module(std::get<1>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* dat)->at($(int64_t i))));
             }|]
      pure (key, val)

-- | List the (name, module) pairs reported by @named_children()@.
--
-- The pairs are first copied into a temporary C++ vector, which 'bracket'
-- deletes once every entry has been read out. Each returned string and
-- module is a fresh heap copy (@new@) that this function does not free.
getNamedChildren :: Ptr Module -> IO [(Ptr StdString, Ptr Module)]
getNamedChildren _obj =
  bracket acquire release $ \dat -> do
    count <-
      [C.throwBlock| int64_t {
        return (long int)$(std::vector<std::tuple<std::string,torch::jit::script::Module>>* dat)->size();
      }|]
    traverse (entryAt dat) [0 .. count - 1]
  where
    -- Snapshot named_children() into a heap-allocated vector.
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, Module))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,torch::jit::script::Module>>* {
        auto module = $(torch::jit::script::Module* _obj);
        auto entries = module->named_children();
        auto snapshot = new std::vector<std::tuple<std::string,torch::jit::script::Module>>();
        for (auto entry : entries) {
          snapshot->push_back({entry.name, entry.value});
        }
        return snapshot;
      }|]

    -- Delete the snapshot vector.
    release :: Ptr (StdVector (StdTuple '(StdString, Module))) -> IO ()
    release dat =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,torch::jit::script::Module>>* dat);
      }|]

    -- Copy out the name and module stored at index i.
    entryAt :: Ptr (StdVector (StdTuple '(StdString, Module))) -> Int64 -> IO (Ptr StdString, Ptr Module)
    entryAt dat i = do
      key <- [C.throwBlock| std::string* {
               return new std::string(std::get<0>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* dat)->at($(int64_t i))));
             }|]
      val <- [C.throwBlock| torch::jit::script::Module* {
               return new torch::jit::script::Module(std::get<1>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* dat)->at($(int64_t i))));
             }|]
      pure (key, val)


-- | Move the module to the given device type and device index
-- (@Module::to@).
toDevice :: Ptr Module -> DeviceType -> Int16 -> IO ()
toDevice obj device device_index =
  [C.throwBlock| void {
    torch::Device target($(at::DeviceType device), $(int16_t device_index));
    $(torch::jit::script::Module* obj)->to(target);
  }|]

-- | Return a heap-allocated (@new@) copy of @Module::clone()@. The copy is
-- not freed here.
clone :: Ptr Module -> IO (Ptr Module)
clone obj =
  [C.throwBlock| torch::jit::script::Module* {
    auto* original = $(torch::jit::script::Module* obj);
    return new torch::jit::script::Module(original->clone());
  }|]

-- | Compile the given TorchScript source into the module
-- (@Module::define@).
define :: Ptr Module -> Ptr StdString -> IO ()
define obj src =
  [C.throwBlock| void {
    const std::string& source = *$(std::string* src);
    $(torch::jit::script::Module* obj)->define(source);
  }|]

-- | Trace a Haskell tensor function into a new TorchScript module.
--
-- * @moduleName@: name of the fresh module to create.
-- * @functionName@: name under which the traced graph is added as a
--   method of that module.
-- * @func@: the Haskell function to trace. It is exported to C++ through
--   'callbackHelper' and released with 'freeHaskellFunPtr' afterwards.
-- * @inputs@: example tensors driving the trace.
--
-- The resulting module is heap-allocated with @new@ and not freed here.
trace :: CString -> CString -> (Ptr TensorList -> IO (Ptr TensorList)) -> Ptr TensorList -> IO (Ptr Module)
trace moduleName functionName func inputs =
  bracket (callbackHelper untyped) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| torch::jit::script::Module* {
      torch::jit::script::Module self($(char* moduleName));
      auto example_inputs = *$(std::vector<at::Tensor>* inputs);
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef std::vector<at::Tensor>* (*Func)(std::vector<at::Tensor>*);
      auto func = (Func)tfunc;
      auto graph = torch::jit::tracer::trace(
        c10::fmap<c10::IValue>(example_inputs),
        [&func](c10::Stack in) -> c10::Stack {
          // NOTE(review): neither this vector nor the one returned by func
          // is deleted here; confirm the Haskell side takes ownership.
          std::vector<at::Tensor>* tensor_args = new std::vector<at::Tensor>(c10::fmap(in, [](const c10::IValue& v){
            return torch::autograd::Variable(v.toTensor());
          }));
          std::vector<at::Tensor> out = *(func(tensor_args));
          return c10::fmap<c10::IValue>(out);
        },
        [](const torch::autograd::Variable& var) { return "";}
      ).first->graph;
      // Prepend a "self" input typed as the module so the graph can be
      // installed as a method.
      auto self_input = graph->insertInput(0, "self");
      self_input->setType(self._ivalue()->type());
      const auto name = c10::QualifiedName(*self.type()->name(), $(char* functionName));
      auto method = self._ivalue()->compilation_unit()->create_function(name, graph);
      self.type()->addMethod(method);
      return new torch::jit::script::Module(self);
    }|]
  where
    -- Erase the tensor-list pointer types to match callbackHelper's
    -- void*-based signature.
    untyped :: Ptr () -> IO (Ptr ())
    untyped raw = castPtr <$> func (castPtr raw)

-- | Trace a Haskell tensor function on the given example inputs and return
-- the recorded graph.
--
-- The callback is exported to C++ through 'callbackHelper' and released
-- with 'freeHaskellFunPtr' afterwards. The returned @shared_ptr@ is
-- heap-allocated with @new@ and not freed here.
traceAsGraph :: (Ptr TensorList -> IO (Ptr TensorList)) -> Ptr TensorList -> IO (Ptr (SharedPtr JitGraph))
traceAsGraph func inputs =
  bracket (callbackHelper untyped) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| std::shared_ptr<torch::jit::Graph>* {
      torch::jit::script::Module self("MyModule");
      auto example_inputs = *$(std::vector<at::Tensor>* inputs);
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef std::vector<at::Tensor>* (*Func)(std::vector<at::Tensor>*);
      auto func = (Func)tfunc;
      auto graph = torch::jit::tracer::trace(
        c10::fmap<c10::IValue>(example_inputs),
        [&func](c10::Stack in) -> c10::Stack {
          // NOTE(review): neither this vector nor the one returned by func
          // is deleted here; confirm the Haskell side takes ownership.
          std::vector<at::Tensor>* tensor_args = new std::vector<at::Tensor>(c10::fmap(in, [](const c10::IValue& v){
            return torch::autograd::Variable(v.toTensor());
          }));
          std::vector<at::Tensor> out = *(func(tensor_args));
          return c10::fmap<c10::IValue>(out);
        },
        [](const torch::autograd::Variable& var) { return "";}
      ).first->graph;
      return new std::shared_ptr<torch::jit::Graph>(graph);
    }|]
  where
    -- Erase the tensor-list pointer types to match callbackHelper's
    -- void*-based signature.
    untyped :: Ptr () -> IO (Ptr ())
    untyped raw = castPtr <$> func (castPtr raw)

-- | Run an action on the raw @Graph*@ held by a @shared_ptr<Graph>@.
--
-- The raw pointer is borrowed from the shared_ptr; it is not retained
-- beyond whatever the shared_ptr itself keeps alive.
withJitGraph :: Ptr (SharedPtr JitGraph) -> (Ptr JitGraph -> IO a) -> IO a
withJitGraph graph callback =
  [C.throwBlock| torch::jit::Graph* {
    return $(std::shared_ptr<torch::jit::Graph>* graph)->get();
  }|]
    >>= callback

-- | Return the graph's output values, in the order @Graph::outputs()@
-- yields them.
--
-- A Haskell callback (exported via 'callbackHelper') is invoked for each
-- output and accumulates the pointers in an 'IORef'. The list is built in
-- reverse and flipped at the end. The pointers are borrowed from the graph.
graphOutputs :: Ptr JitGraph -> IO [Ptr JitValue]
graphOutputs graph = do
  seen <- newIORef []
  let record :: Ptr JitValue -> IO (Ptr JitValue)
      record v = modifyIORef' seen (v :) >> pure v
  bracket (callbackHelper (fmap castPtr . record . castPtr)) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Value* (*Func)(torch::jit::Value*);
      auto func = (Func)tfunc;
      for (auto out : (*$(torch::jit::Graph* graph)).outputs()) {
        func(out);
      }
    }|]
  reverse <$> readIORef seen

-- | Return the graph's input values, in the order @Graph::inputs()@
-- yields them.
--
-- A Haskell callback (exported via 'callbackHelper') is invoked for each
-- input and accumulates the pointers in an 'IORef'. The list is built in
-- reverse and flipped at the end. The pointers are borrowed from the graph.
graphInputs :: Ptr JitGraph -> IO [Ptr JitValue]
graphInputs graph = do
  seen <- newIORef []
  let record :: Ptr JitValue -> IO (Ptr JitValue)
      record v = modifyIORef' seen (v :) >> pure v
  bracket (callbackHelper (fmap castPtr . record . castPtr)) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Value* (*Func)(torch::jit::Value*);
      auto func = (Func)tfunc;
      for (auto in : (*$(torch::jit::Graph* graph)).inputs()) {
        func(in);
      }
    }|]
  reverse <$> readIORef seen

-- | Return the nodes of the graph's top-level block (@block()->nodes()@),
-- in iteration order.
--
-- A Haskell callback (exported via 'callbackHelper') is invoked for each
-- node and accumulates the pointers in an 'IORef'. The list is built in
-- reverse and flipped at the end. The pointers are borrowed from the graph.
graphNodes :: Ptr JitGraph -> IO [Ptr JitNode]
graphNodes graph = do
  seen <- newIORef []
  let record :: Ptr JitNode -> IO (Ptr JitNode)
      record n = modifyIORef' seen (n :) >> pure n
  bracket (callbackHelper (fmap castPtr . record . castPtr)) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Node* (*Func)(torch::jit::Node*);
      auto func = (Func)tfunc;
      for (auto node : (*$(torch::jit::Graph* graph)).block()->nodes()) {
        func(node);
      }
    }|]
  reverse <$> readIORef seen

-- | Return a node's input values, in the order @Node::inputs()@ yields
-- them.
--
-- A Haskell callback (exported via 'callbackHelper') is invoked for each
-- value and accumulates the pointers in an 'IORef'. The list is built in
-- reverse and flipped at the end. The pointers are borrowed from the node.
nodeInputs :: Ptr JitNode -> IO [Ptr JitValue]
nodeInputs node = do
  seen <- newIORef []
  let record :: Ptr JitValue -> IO (Ptr JitValue)
      record v = modifyIORef' seen (v :) >> pure v
  bracket (callbackHelper (fmap castPtr . record . castPtr)) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Value* (*Func)(torch::jit::Value*);
      auto func = (Func)tfunc;
      for (auto in : (*$(torch::jit::Node* node)).inputs()) {
        func(in);
      }
    }|]
  reverse <$> readIORef seen

-- | Return a node's output values, in the order @Node::outputs()@ yields
-- them.
--
-- A Haskell callback (exported via 'callbackHelper') is invoked for each
-- value and accumulates the pointers in an 'IORef'. The list is built in
-- reverse and flipped at the end. The pointers are borrowed from the node.
nodeOutputs :: Ptr JitNode -> IO [Ptr JitValue]
nodeOutputs node = do
  seen <- newIORef []
  let record :: Ptr JitValue -> IO (Ptr JitValue)
      record v = modifyIORef' seen (v :) >> pure v
  bracket (callbackHelper (fmap castPtr . record . castPtr)) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Value* (*Func)(torch::jit::Value*);
      auto func = (Func)tfunc;
      for (auto out : (*$(torch::jit::Node* node)).outputs()) {
        func(out);
      }
    }|]
  reverse <$> readIORef seen

-- | Return the node's kind as its qualified symbol string
-- (@kind().toQualString()@), copied into a new heap string.
nodeKind :: Ptr JitNode -> IO (Ptr StdString)
nodeKind node =
  [C.throwBlock| std::string* {
    auto* n = $(torch::jit::Node* node);
    return new std::string(n->kind().toQualString());
  }|]

-- | Return the value's unique id within its graph (@Value::unique()@),
-- converted to a C @int@.
valueId :: Ptr JitValue -> IO CInt
valueId value =
  [C.throwBlock| int {
    auto* v = $(torch::jit::Value* value);
    return v->unique();
  }|]

-- | Render the value's type (@type()->str()@) into a new heap string.
valueType :: Ptr JitValue -> IO (Ptr StdString)
valueType node =
  [C.throwBlock| std::string* {
    auto* v = $(torch::jit::Value* node);
    return new std::string(v->type()->str());
  }|]

-- | Render the graph held by the shared_ptr with @Graph::toString()@ and
-- return it as a new heap string.
printGraph :: Ptr (SharedPtr JitGraph) -> IO (Ptr StdString)
printGraph graph =
  [C.throwBlock| std::string* {
    const auto& g = *$(std::shared_ptr<torch::jit::Graph>* graph);
    return new std::string(g->toString());
  }|]

-- | Pretty-print the graph as ONNX via @torch::jit::pretty_print_onnx@.
--
-- The call uses no initializers, opset version 9 and
-- @defer_weight_export = false@. The text is returned as a new heap string.
printOnnx :: Ptr (SharedPtr JitGraph) -> IO (Ptr StdString)
printOnnx graph =
  [C.throwBlock| std::string* {
    const auto& g = *$(std::shared_ptr<torch::jit::Graph>* graph);
    std::map<std::string, at::Tensor> no_initializers{};
    const int64_t opset_version = 9;
    const bool defer_weight_export = false;
    return new std::string(torch::jit::pretty_print_onnx(
        g, no_initializers, opset_version, defer_weight_export));
  }|]

-- | Dump the module with @Module::dump_to_str@ and return the text as a
-- new heap string.
--
-- The three flags select, in order, whether method bodies, attribute
-- values and parameter values are included.
dumpToStr :: Ptr Module -> CBool -> CBool -> CBool -> IO (Ptr StdString)
dumpToStr obj printMethodBodies printAttrValues printParamValues =
  [C.throwBlock| std::string* {
    auto* self = $(torch::jit::script::Module* obj);
    return new std::string(self->dump_to_str(
      $(bool printMethodBodies),
      $(bool printAttrValues),
      $(bool printParamValues)));
  }|]