{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
module Torch.Internal.Unmanaged.Type.Module where
import Control.Exception.Safe (bracket)
import Data.IORef
import qualified Data.Map as Map
import Foreign
import Foreign.C.String
import Foreign.C.Types
import qualified Language.C.Inline.Context as C
import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Cpp.Unsafe as C
import qualified Language.C.Inline.Cpp.Exceptions as Safe
import qualified Language.C.Types as C
import Torch.Internal.Type
import Torch.Internal.Unmanaged.Helper
import Control.Exception.Safe (bracket)
import Control.Monad (forM)
C.context $ C.cppCtx <> mempty {C.ctxTypesTable = typeTable}
C.include "<torch/script.h>"
C.include "<torch/csrc/jit/serialization/export.h>"
C.include "<torch/csrc/jit/frontend/tracer.h>"
C.include "<vector>"
C.include "<iostream>"
-- | Allocate a new TorchScript module constructed from the name held in the
-- given @std::string@.
--
-- The module is allocated with C++ @new@; the returned pointer is owned by the
-- caller. C++ exceptions surface as Haskell exceptions via 'C.throwBlock'.
newModule :: Ptr StdString -> IO (Ptr Module)
newModule moduleName =
  [C.throwBlock| torch::jit::script::Module* {
    return new torch::jit::script::Module(*$(std::string* moduleName));
  }|]
-- | Serialize a module to the file at the given path using @Module::save@.
--
-- The path is marshalled to a NUL-terminated C string with 'withCString' for
-- the duration of the call only.
save :: Ptr Module -> FilePath -> IO ()
save obj file = withCString file saveTo
  where
    saveTo :: CString -> IO ()
    saveTo path =
      [C.throwBlock| void {
        $(torch::jit::script::Module* obj)->save($(char* path));
      }|]
-- | Load a serialized TorchScript module from the given file path with
-- @torch::jit::load@.
--
-- The loaded module is copied into a fresh @new@ allocation owned by the
-- caller.
load :: FilePath -> IO (Ptr Module)
load file = withCString file loadFrom
  where
    loadFrom :: CString -> IO (Ptr Module)
    loadFrom path =
      [C.throwBlock| torch::jit::script::Module* {
        auto loaded = torch::jit::load($(char* path));
        return new torch::jit::script::Module(loaded);
      }|]
-- | Call the module's @forward@ method with a vector of argument values.
--
-- The result is placed in a freshly heap-allocated @at::IValue@ owned by the
-- caller.
forward :: Ptr Module -> Ptr (StdVector IValue) -> IO (Ptr IValue)
forward obj inputs =
  [C.throwBlock| at::IValue* {
    auto& self_module = *$(torch::jit::script::Module* obj);
    auto& args = *$(std::vector<at::IValue>* inputs);
    return new at::IValue(self_module.forward(args));
  }|]
-- | Register tensor @v@ on the module under @name@ via
-- @Module::register_parameter@.
--
-- The 'CBool' flag is forwarded unchanged as the third (@is_buffer@) argument
-- of @register_parameter@.
registerParameter :: Ptr Module -> Ptr StdString -> Ptr Tensor -> CBool -> IO ()
registerParameter obj name v isBuffer =
  [C.throwBlock| void {
    auto& target = *$(torch::jit::script::Module* obj);
    target.register_parameter(*$(std::string* name), *$(at::Tensor* v), $(bool isBuffer));
  }|]
-- | Attach @child@ to @parent@ as a submodule named by the given string, via
-- @Module::register_module@.
registerModule :: Ptr Module -> Ptr StdString -> Ptr Module -> IO ()
registerModule parent childName child =
  [C.throwBlock| void {
    auto& target = *$(torch::jit::script::Module* parent);
    target.register_module(*$(std::string* childName), *$(torch::jit::script::Module* child));
  }|]
-- | Forward the given flag to @Module::train@.
train :: Ptr Module -> CBool -> IO ()
train obj enabled =
  [C.throwBlock| void {
    auto& target = *$(torch::jit::script::Module* obj);
    target.train($(bool enabled));
  }|]
-- | Invoke the named method on the module, passing the @c10::List@ as the
-- single argument of @run_method@.
--
-- The result is copied into a fresh @at::IValue@ owned by the caller.
runMethod :: Ptr Module -> Ptr StdString -> Ptr (C10List IValue) -> IO (Ptr IValue)
runMethod obj methodName args =
  [C.throwBlock| at::IValue* {
    auto& target = *$(torch::jit::script::Module* obj);
    return new at::IValue(target.run_method(
      *$(std::string* methodName),
      *$(c10::List<at::IValue>* args)));
  }|]
-- | Invoke the named method on the module with exactly one 'IValue' argument.
--
-- The result is copied into a fresh @at::IValue@ owned by the caller.
runMethod1 :: Ptr Module -> Ptr StdString -> Ptr IValue -> IO (Ptr IValue)
runMethod1 obj methodName arg =
  [C.throwBlock| at::IValue* {
    auto& target = *$(torch::jit::script::Module* obj);
    return new at::IValue(target.run_method(
      *$(std::string* methodName),
      *$(at::IValue* arg)));
  }|]
-- | Copy every tensor yielded by the module's @parameters()@ into a newly
-- allocated @std::vector<at::Tensor>@ owned by the caller.
getParameters :: Ptr Module -> IO (Ptr TensorList)
getParameters obj =
  [C.throwBlock| std::vector<at::Tensor>* {
    auto collected = new std::vector<at::Tensor>();
    auto source = $(torch::jit::script::Module* obj)->parameters();
    for (const auto& tensor : source) {
      collected->push_back(tensor);
    }
    return collected;
  }|]
-- | Overwrite the module's parameters with the tensors of the given vector.
--
-- Tensors are consumed in @named_parameters()@ order, presumably the same
-- order 'getParameters' produces (confirm). The i-th tensor replaces the i-th
-- parameter and is re-registered as a parameter, not a buffer. Surplus
-- tensors in the vector are ignored.
--
-- Fixes relative to the previous implementation:
--
-- * The index now advances on every iteration. Before, it stayed at 0, so
--   every parameter was overwritten with the first tensor.
-- * @vec->at@ is used instead of unchecked @operator[]@. A vector shorter than
--   the parameter list raises @std::out_of_range@, which 'C.throwBlock'
--   rethrows as a Haskell exception, instead of reading past the end.
-- * @named_parameters()@ recurses into submodules and reports their
--   parameters under dot-separated paths (e.g. @sub.weight@). Each path is
--   now resolved to the owning submodule, and the parameter is re-registered
--   there under its last path component. Previously a new dotted attribute
--   was added to the root module, and the submodule was left untouched.
setParameters :: Ptr Module -> Ptr TensorList -> IO ()
setParameters obj params =
  [C.throwBlock| void {
    auto module = $(torch::jit::script::Module* obj);
    auto parameters = module->named_parameters();
    auto vec = $(std::vector<at::Tensor>* params);
    size_t i = 0;
    for (const auto& p : parameters) {
      torch::jit::script::Module owner = *module;
      std::string leaf = p.name;
      for (size_t dot = leaf.find('.'); dot != std::string::npos; dot = leaf.find('.')) {
        owner = torch::jit::script::Module(owner.attr(leaf.substr(0, dot)).toObject());
        leaf = leaf.substr(dot + 1);
      }
      owner.register_parameter(leaf, vec->at(i++), false);
    }
  }|]
-- | Snapshot the module's @named_parameters()@ as (name, tensor) pointer pairs.
--
-- The entries are first copied into a temporary C++ vector. 'bracket' deletes
-- that vector once every entry has been copied out, even if copying throws.
-- Each returned @std::string*@ and @at::Tensor*@ is a separate @new@
-- allocation owned by the caller.
getNamedParameters :: Ptr Module -> IO [(Ptr StdString, Ptr Tensor)]
getNamedParameters modulePtr = bracket acquire release collect
  where
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, Tensor))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,at::Tensor>>* {
        auto mod = $(torch::jit::script::Module* modulePtr);
        auto named = mod->named_parameters();
        auto entries = new std::vector<std::tuple<std::string,at::Tensor>>();
        for (const auto& item : named) {
          entries->push_back({item.name, item.value});
        }
        return entries;
      }|]

    release :: Ptr (StdVector (StdTuple '(StdString, Tensor))) -> IO ()
    release entries =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,at::Tensor>>* entries);
      }|]

    collect :: Ptr (StdVector (StdTuple '(StdString, Tensor))) -> IO [(Ptr StdString, Ptr Tensor)]
    collect entries = do
      count <- [C.throwBlock| int64_t {
        return static_cast<long int>($(std::vector<std::tuple<std::string,at::Tensor>>* entries)->size());
      }|]
      forM [0 .. count - 1] $ \i -> do
        key <- [C.throwBlock| std::string* {
          return new std::string(std::get<0>($(std::vector<std::tuple<std::string,at::Tensor>>* entries)->at($(int64_t i))));
        }|]
        val <- [C.throwBlock| at::Tensor* {
          return new at::Tensor(std::get<1>($(std::vector<std::tuple<std::string,at::Tensor>>* entries)->at($(int64_t i))));
        }|]
        pure (key, val)
-- | Snapshot the module's @named_buffers()@ as (name, tensor) pointer pairs.
--
-- The entries are first copied into a temporary C++ vector. 'bracket' deletes
-- that vector once every entry has been copied out, even if copying throws.
-- Each returned @std::string*@ and @at::Tensor*@ is a separate @new@
-- allocation owned by the caller.
getNamedBuffers :: Ptr Module -> IO [(Ptr StdString, Ptr Tensor)]
getNamedBuffers modulePtr = bracket acquire release collect
  where
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, Tensor))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,at::Tensor>>* {
        auto mod = $(torch::jit::script::Module* modulePtr);
        auto named = mod->named_buffers();
        auto entries = new std::vector<std::tuple<std::string,at::Tensor>>();
        for (const auto& item : named) {
          entries->push_back({item.name, item.value});
        }
        return entries;
      }|]

    release :: Ptr (StdVector (StdTuple '(StdString, Tensor))) -> IO ()
    release entries =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,at::Tensor>>* entries);
      }|]

    collect :: Ptr (StdVector (StdTuple '(StdString, Tensor))) -> IO [(Ptr StdString, Ptr Tensor)]
    collect entries = do
      count <- [C.throwBlock| int64_t {
        return static_cast<long int>($(std::vector<std::tuple<std::string,at::Tensor>>* entries)->size());
      }|]
      forM [0 .. count - 1] $ \i -> do
        key <- [C.throwBlock| std::string* {
          return new std::string(std::get<0>($(std::vector<std::tuple<std::string,at::Tensor>>* entries)->at($(int64_t i))));
        }|]
        val <- [C.throwBlock| at::Tensor* {
          return new at::Tensor(std::get<1>($(std::vector<std::tuple<std::string,at::Tensor>>* entries)->at($(int64_t i))));
        }|]
        pure (key, val)
-- | Snapshot the module's @named_attributes()@ as (name, value) pointer pairs.
--
-- The entries are first copied into a temporary C++ vector. 'bracket' deletes
-- that vector once every entry has been copied out, even if copying throws.
-- Each returned @std::string*@ and @at::IValue*@ is a separate @new@
-- allocation owned by the caller.
getNamedAttributes :: Ptr Module -> IO [(Ptr StdString, Ptr IValue)]
getNamedAttributes modulePtr = bracket acquire release collect
  where
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, IValue))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,at::IValue>>* {
        auto mod = $(torch::jit::script::Module* modulePtr);
        auto named = mod->named_attributes();
        auto entries = new std::vector<std::tuple<std::string,at::IValue>>();
        for (const auto& item : named) {
          entries->push_back({item.name, item.value});
        }
        return entries;
      }|]

    release :: Ptr (StdVector (StdTuple '(StdString, IValue))) -> IO ()
    release entries =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,at::IValue>>* entries);
      }|]

    collect :: Ptr (StdVector (StdTuple '(StdString, IValue))) -> IO [(Ptr StdString, Ptr IValue)]
    collect entries = do
      count <- [C.throwBlock| int64_t {
        return static_cast<long int>($(std::vector<std::tuple<std::string,at::IValue>>* entries)->size());
      }|]
      forM [0 .. count - 1] $ \i -> do
        key <- [C.throwBlock| std::string* {
          return new std::string(std::get<0>($(std::vector<std::tuple<std::string,at::IValue>>* entries)->at($(int64_t i))));
        }|]
        val <- [C.throwBlock| at::IValue* {
          return new at::IValue(std::get<1>($(std::vector<std::tuple<std::string,at::IValue>>* entries)->at($(int64_t i))));
        }|]
        pure (key, val)
-- | Snapshot the module's @named_modules()@ as (name, module) pointer pairs.
--
-- The entries are first copied into a temporary C++ vector. 'bracket' deletes
-- that vector once every entry has been copied out, even if copying throws.
-- Each returned @std::string*@ and module pointer is a separate @new@
-- allocation owned by the caller.
getNamedModules :: Ptr Module -> IO [(Ptr StdString, Ptr Module)]
getNamedModules modulePtr = bracket acquire release collect
  where
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, Module))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,torch::jit::script::Module>>* {
        auto mod = $(torch::jit::script::Module* modulePtr);
        auto named = mod->named_modules();
        auto entries = new std::vector<std::tuple<std::string,torch::jit::script::Module>>();
        for (const auto& item : named) {
          entries->push_back({item.name, item.value});
        }
        return entries;
      }|]

    release :: Ptr (StdVector (StdTuple '(StdString, Module))) -> IO ()
    release entries =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,torch::jit::script::Module>>* entries);
      }|]

    collect :: Ptr (StdVector (StdTuple '(StdString, Module))) -> IO [(Ptr StdString, Ptr Module)]
    collect entries = do
      count <- [C.throwBlock| int64_t {
        return static_cast<long int>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* entries)->size());
      }|]
      forM [0 .. count - 1] $ \i -> do
        key <- [C.throwBlock| std::string* {
          return new std::string(std::get<0>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* entries)->at($(int64_t i))));
        }|]
        val <- [C.throwBlock| torch::jit::script::Module* {
          return new torch::jit::script::Module(std::get<1>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* entries)->at($(int64_t i))));
        }|]
        pure (key, val)
-- | Snapshot the module's @named_children()@ as (name, module) pointer pairs.
--
-- The entries are first copied into a temporary C++ vector. 'bracket' deletes
-- that vector once every entry has been copied out, even if copying throws.
-- Each returned @std::string*@ and module pointer is a separate @new@
-- allocation owned by the caller.
getNamedChildren :: Ptr Module -> IO [(Ptr StdString, Ptr Module)]
getNamedChildren modulePtr = bracket acquire release collect
  where
    acquire :: IO (Ptr (StdVector (StdTuple '(StdString, Module))))
    acquire =
      [C.throwBlock| std::vector<std::tuple<std::string,torch::jit::script::Module>>* {
        auto mod = $(torch::jit::script::Module* modulePtr);
        auto named = mod->named_children();
        auto entries = new std::vector<std::tuple<std::string,torch::jit::script::Module>>();
        for (const auto& item : named) {
          entries->push_back({item.name, item.value});
        }
        return entries;
      }|]

    release :: Ptr (StdVector (StdTuple '(StdString, Module))) -> IO ()
    release entries =
      [C.throwBlock| void {
        delete $(std::vector<std::tuple<std::string,torch::jit::script::Module>>* entries);
      }|]

    collect :: Ptr (StdVector (StdTuple '(StdString, Module))) -> IO [(Ptr StdString, Ptr Module)]
    collect entries = do
      count <- [C.throwBlock| int64_t {
        return static_cast<long int>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* entries)->size());
      }|]
      forM [0 .. count - 1] $ \i -> do
        key <- [C.throwBlock| std::string* {
          return new std::string(std::get<0>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* entries)->at($(int64_t i))));
        }|]
        val <- [C.throwBlock| torch::jit::script::Module* {
          return new torch::jit::script::Module(std::get<1>($(std::vector<std::tuple<std::string,torch::jit::script::Module>>* entries)->at($(int64_t i))));
        }|]
        pure (key, val)
-- | Move the module with @Module::to@ onto the @torch::Device@ built from the
-- given device type and device index.
toDevice :: Ptr Module -> DeviceType -> Int16 -> IO ()
toDevice obj deviceType deviceIndex =
  [C.throwBlock| void {
    torch::Device target($(at::DeviceType deviceType), $(int16_t deviceIndex));
    $(torch::jit::script::Module* obj)->to(target);
  }|]
-- | Return the result of @Module::clone()@ as a fresh heap-allocated module
-- owned by the caller.
clone :: Ptr Module -> IO (Ptr Module)
clone original =
  [C.throwBlock| torch::jit::script::Module* {
    auto copy = $(torch::jit::script::Module* original)->clone();
    return new torch::jit::script::Module(copy);
  }|]
-- | Pass the given source text to @Module::define@ on the module.
define :: Ptr Module -> Ptr StdString -> IO ()
define obj source =
  [C.throwBlock| void {
    auto& target = *$(torch::jit::script::Module* obj);
    target.define(*$(std::string* source));
  }|]
-- | Trace a Haskell tensor function into a new TorchScript module.
--
-- @func@ is exposed to C++ through 'callbackHelper' as a @void* (*)(void*)@
-- function pointer. 'bracket' frees that 'FunPtr' once tracing finishes or
-- throws. The call runs through 'Safe.throwBlock' because C++ calls back
-- into Haskell.
--
-- The traced graph gets a leading input named @self@, typed as the new module
-- named @moduleName@. The graph is then installed on that module as the
-- method @functionName@. The module is returned as a fresh @new@ allocation.
--
-- NOTE(review): inside the tracing lambda, neither the vector allocated for
-- each callback invocation nor the vector returned by the callback is deleted.
-- Presumably the Haskell side owns them; confirm.
trace
  :: CString
  -> CString
  -> (Ptr TensorList -> IO (Ptr TensorList))
  -> Ptr TensorList
  -> IO (Ptr Module)
trace moduleName functionName func inputs =
  bracket (callbackHelper untyped) freeHaskellFunPtr traceWith
  where
    -- Erase the pointer types so the callback matches 'callbackHelper'.
    untyped :: Ptr () -> IO (Ptr ())
    untyped raw = castPtr <$> func (castPtr raw)

    traceWith :: FunPtr (Ptr () -> IO (Ptr ())) -> IO (Ptr Module)
    traceWith funcPtr =
      [Safe.throwBlock| torch::jit::script::Module* {
        torch::jit::script::Module self($(char* moduleName));
        auto vars_in = *$(std::vector<at::Tensor>* inputs);
        auto tfunc = $(void* (*funcPtr)(void*));
        typedef std::vector<at::Tensor>* (*Func)(std::vector<at::Tensor>*);
        auto func = (Func)tfunc;
        auto graph = torch::jit::tracer::trace(
          c10::fmap<c10::IValue>(vars_in),
          [&func](c10::Stack in) -> c10::Stack {
            std::vector<at::Tensor>* ivalue_inps = new std::vector<at::Tensor>(c10::fmap(in, [](const c10::IValue& v){
              return torch::autograd::Variable(v.toTensor());
            }));
            std::vector<at::Tensor> out = *(func(ivalue_inps));
            return c10::fmap<c10::IValue>(out);
          },
          [](const torch::autograd::Variable& var) { return "";}
        ).first->graph;
        auto v = graph->insertInput(0, "self");
        v->setType(self._ivalue()->type());
        const auto name = c10::QualifiedName(*self.type()->name(), $(char* functionName));
        auto fn2 = self._ivalue()->compilation_unit()->create_function(name,graph);
        self.type()->addMethod(fn2);
        return new torch::jit::script::Module(self);
      }|]
-- | Trace a Haskell tensor function and return the resulting JIT graph.
--
-- Uses the same callback mechanism as module tracing. The callback's
-- 'FunPtr' comes from 'callbackHelper' and is freed by 'bracket'. The traced
-- graph is returned in a freshly allocated @std::shared_ptr@ owned by the
-- caller.
--
-- NOTE(review): the per-invocation input vector and the callback's result
-- vector are not deleted inside the tracing lambda. Presumably the Haskell
-- side owns them; confirm.
traceAsGraph
  :: (Ptr TensorList -> IO (Ptr TensorList))
  -> Ptr TensorList
  -> IO (Ptr (SharedPtr JitGraph))
traceAsGraph func inputs =
  bracket (callbackHelper untyped) freeHaskellFunPtr traceWith
  where
    -- Erase the pointer types so the callback matches 'callbackHelper'.
    untyped :: Ptr () -> IO (Ptr ())
    untyped raw = castPtr <$> func (castPtr raw)

    traceWith :: FunPtr (Ptr () -> IO (Ptr ())) -> IO (Ptr (SharedPtr JitGraph))
    traceWith funcPtr =
      [Safe.throwBlock| std::shared_ptr<torch::jit::Graph>* {
        torch::jit::script::Module self("MyModule");
        auto vars_in = *$(std::vector<at::Tensor>* inputs);
        auto tfunc = $(void* (*funcPtr)(void*));
        typedef std::vector<at::Tensor>* (*Func)(std::vector<at::Tensor>*);
        auto func = (Func)tfunc;
        auto graph = torch::jit::tracer::trace(
          c10::fmap<c10::IValue>(vars_in),
          [&func](c10::Stack in) -> c10::Stack {
            std::vector<at::Tensor>* ivalue_inps = new std::vector<at::Tensor>(c10::fmap(in, [](const c10::IValue& v){
              return torch::autograd::Variable(v.toTensor());
            }));
            std::vector<at::Tensor> out = *(func(ivalue_inps));
            return c10::fmap<c10::IValue>(out);
          },
          [](const torch::autograd::Variable& var) { return "";}
        ).first->graph;
        return new std::shared_ptr<torch::jit::Graph>(graph);
      }|]
-- | Run a callback with the raw 'JitGraph' pointer held by a
-- @std::shared_ptr@.
--
-- The pointer comes from @shared_ptr::get()@ and is borrowed, with no extra
-- reference taken. It stays valid only while the shared pointer keeps the
-- graph alive.
withJitGraph :: Ptr (SharedPtr JitGraph) -> (Ptr JitGraph -> IO a) -> IO a
withJitGraph graph callback = rawGraph >>= callback
  where
    rawGraph :: IO (Ptr JitGraph)
    rawGraph =
      [C.throwBlock| torch::jit::Graph* {
        return $(std::shared_ptr<torch::jit::Graph>* graph)->get();
      }|]
-- | List the graph's output values in the order C++ iterates @outputs()@.
--
-- C++ calls a Haskell callback (built with 'callbackHelper') once per output.
-- The callback prepends the pointer to an 'IORef'. The list is reversed at
-- the end to restore iteration order. 'bracket' frees the callback's
-- 'FunPtr'. The returned pointers are not newly allocated; they come straight
-- from the graph.
graphOutputs :: Ptr JitGraph -> IO [Ptr JitValue]
graphOutputs graph = do
  seen <- newIORef []
  let record :: Ptr JitValue -> IO (Ptr JitValue)
      record value = modifyIORef' seen (value :) >> pure value
      untyped :: Ptr () -> IO (Ptr ())
      untyped raw = castPtr <$> record (castPtr raw)
  bracket (callbackHelper untyped) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Value* (*Func)(torch::jit::Value*);
      auto func = (Func)tfunc;
      for (auto v : $(torch::jit::Graph* graph)->outputs()) {
        func(v);
      }
    }|]
  reverse <$> readIORef seen
-- | List the graph's input values in the order C++ iterates @inputs()@.
--
-- C++ calls a Haskell callback (built with 'callbackHelper') once per input.
-- The callback prepends the pointer to an 'IORef'. The list is reversed at
-- the end to restore iteration order. 'bracket' frees the callback's
-- 'FunPtr'. The returned pointers are not newly allocated; they come straight
-- from the graph.
graphInputs :: Ptr JitGraph -> IO [Ptr JitValue]
graphInputs graph = do
  seen <- newIORef []
  let record :: Ptr JitValue -> IO (Ptr JitValue)
      record value = modifyIORef' seen (value :) >> pure value
      untyped :: Ptr () -> IO (Ptr ())
      untyped raw = castPtr <$> record (castPtr raw)
  bracket (callbackHelper untyped) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Value* (*Func)(torch::jit::Value*);
      auto func = (Func)tfunc;
      for (auto v : $(torch::jit::Graph* graph)->inputs()) {
        func(v);
      }
    }|]
  reverse <$> readIORef seen
-- | List the nodes of the graph's top-level @block()@ in iteration order.
--
-- Only @block()->nodes()@ is walked; nodes nested inside sub-blocks are not
-- visited. C++ calls a Haskell callback (built with 'callbackHelper') once per
-- node. The callback prepends the node to an 'IORef', and the list is reversed
-- at the end. 'bracket' frees the callback's 'FunPtr'.
graphNodes :: Ptr JitGraph -> IO [Ptr JitNode]
graphNodes graph = do
  seen <- newIORef []
  let record :: Ptr JitNode -> IO (Ptr JitNode)
      record node = modifyIORef' seen (node :) >> pure node
      untyped :: Ptr () -> IO (Ptr ())
      untyped raw = castPtr <$> record (castPtr raw)
  bracket (callbackHelper untyped) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Node* (*Func)(torch::jit::Node*);
      auto func = (Func)tfunc;
      for (auto n : $(torch::jit::Graph* graph)->block()->nodes()) {
        func(n);
      }
    }|]
  reverse <$> readIORef seen
-- | List a node's input values in the order C++ iterates @inputs()@.
--
-- C++ calls a Haskell callback (built with 'callbackHelper') once per value.
-- The callback prepends the value to an 'IORef', and the list is reversed at
-- the end. 'bracket' frees the callback's 'FunPtr'.
nodeInputs :: Ptr JitNode -> IO [Ptr JitValue]
nodeInputs node = do
  seen <- newIORef []
  let record :: Ptr JitValue -> IO (Ptr JitValue)
      record value = modifyIORef' seen (value :) >> pure value
      untyped :: Ptr () -> IO (Ptr ())
      untyped raw = castPtr <$> record (castPtr raw)
  bracket (callbackHelper untyped) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Value* (*Func)(torch::jit::Value*);
      auto func = (Func)tfunc;
      for (auto v : $(torch::jit::Node* node)->inputs()) {
        func(v);
      }
    }|]
  reverse <$> readIORef seen
-- | List a node's output values in the order C++ iterates @outputs()@.
--
-- C++ calls a Haskell callback (built with 'callbackHelper') once per value.
-- The callback prepends the value to an 'IORef', and the list is reversed at
-- the end. 'bracket' frees the callback's 'FunPtr'.
nodeOutputs :: Ptr JitNode -> IO [Ptr JitValue]
nodeOutputs node = do
  seen <- newIORef []
  let record :: Ptr JitValue -> IO (Ptr JitValue)
      record value = modifyIORef' seen (value :) >> pure value
      untyped :: Ptr () -> IO (Ptr ())
      untyped raw = castPtr <$> record (castPtr raw)
  bracket (callbackHelper untyped) freeHaskellFunPtr $ \funcPtr ->
    [Safe.throwBlock| void {
      auto tfunc = $(void* (*funcPtr)(void*));
      typedef torch::jit::Value* (*Func)(torch::jit::Value*);
      auto func = (Func)tfunc;
      for (auto v : $(torch::jit::Node* node)->outputs()) {
        func(v);
      }
    }|]
  reverse <$> readIORef seen
-- | Return a copy of the node's qualified kind string (@kind().toQualString()@)
-- in a fresh @std::string@ owned by the caller.
nodeKind :: Ptr JitNode -> IO (Ptr StdString)
nodeKind node =
  [C.throwBlock| std::string* {
    auto kind = $(torch::jit::Node* node)->kind();
    return new std::string(kind.toQualString());
  }|]
-- | Return the value's @unique()@ identifier, converted to a C @int@.
valueId :: Ptr JitValue -> IO CInt
valueId value =
  [C.throwBlock| int {
    return $(torch::jit::Value* value)->unique();
  }|]
-- | Return the textual form of the value's type (@type()->str()@) in a fresh
-- @std::string@ owned by the caller.
valueType :: Ptr JitValue -> IO (Ptr StdString)
valueType value =
  [C.throwBlock| std::string* {
    auto ty = $(torch::jit::Value* value)->type();
    return new std::string(ty->str());
  }|]
-- | Render the shared graph with @Graph::toString()@ into a fresh
-- @std::string@ owned by the caller.
printGraph :: Ptr (SharedPtr JitGraph) -> IO (Ptr StdString)
printGraph graph =
  [C.throwBlock| std::string* {
    auto& shared = *$(std::shared_ptr<torch::jit::Graph>* graph);
    return new std::string(shared->toString());
  }|]
-- | Pretty-print the shared graph with @torch::jit::pretty_print_onnx@.
--
-- The call passes an empty initializer map and the literal arguments @9@
-- (presumably the ONNX opset version) and @false@. The text is returned in a
-- fresh @std::string@ owned by the caller.
printOnnx :: Ptr (SharedPtr JitGraph) -> IO (Ptr StdString)
printOnnx graph =
  [C.throwBlock| std::string* {
    const auto& shared = *$(std::shared_ptr<torch::jit::Graph>* graph);
    std::map<std::string, at::Tensor> no_initializers{};
    return new std::string(torch::jit::pretty_print_onnx(shared, no_initializers, 9, false));
  }|]
-- | Dump the module with @Module::dump_to_str@ into a fresh @std::string@
-- owned by the caller.
--
-- The three flags are forwarded in order as @print_method_bodies@,
-- @print_attr_values@ and @print_param_values@.
dumpToStr :: Ptr Module -> CBool -> CBool -> CBool -> IO (Ptr StdString)
dumpToStr obj methodBodies attrValues paramValues =
  [C.throwBlock| std::string* {
    auto& target = *$(torch::jit::script::Module* obj);
    auto dumped = target.dump_to_str(
      $(bool methodBodies),
      $(bool attrValues),
      $(bool paramValues));
    return new std::string(dumped);
  }|]