{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Internal.Managed.Type.Module where
import Foreign.C.String
import Foreign.C.Types
import Foreign
import Foreign.ForeignPtr.Unsafe
import Torch.Internal.Type
import Torch.Internal.Class
import Torch.Internal.Cast
import Torch.Internal.Objects
import Control.Monad(forM)
import Control.Concurrent.MVar (MVar(..), newEmptyMVar, putMVar, takeMVar)
import qualified Torch.Internal.Unmanaged.Type.Module as Unmanaged
-- | Create a new module from a C++ @std::string@ (presumably the module's
-- name — confirm against the C++ binding).
--
-- Managed wrapper around 'Unmanaged.newModule'.  '_cast1' marshals the
-- 'ForeignPtr' argument to the raw pointer the unmanaged call expects.  It
-- then converts the returned @'Ptr' 'Module'@ back into a 'ForeignPtr'.
newModule :: ForeignPtr StdString -> IO (ForeignPtr Module)
newModule = _cast1 Unmanaged.newModule
-- | Save the module to the file at the given path.
--
-- Managed wrapper around 'Unmanaged.save'.  '_cast2' converts the module
-- 'ForeignPtr' and the 'FilePath' to their C-side representations.
save :: ForeignPtr Module -> FilePath -> IO ()
save = _cast2 Unmanaged.save
-- | Load a module from the file at the given path.
--
-- Managed wrapper around 'Unmanaged.load'.  '_cast1' marshals the
-- 'FilePath' and wraps the returned @'Ptr' 'Module'@ in a 'ForeignPtr'.
load :: FilePath -> IO (ForeignPtr Module)
load = _cast1 Unmanaged.load
-- | Call 'Unmanaged.forward' on the module with a @std::vector@ of 'IValue'
-- inputs, and return the resulting 'IValue' as a 'ForeignPtr'.
forward
  :: ForeignPtr Module
  -> ForeignPtr (StdVector IValue)
  -> IO (ForeignPtr IValue)
forward = _cast2 Unmanaged.forward
-- | Register a tensor on the module under the given name, via
-- 'Unmanaged.registerParameter'.
--
-- The 'CBool' is forwarded unchanged.  NOTE(review): presumably this is the
-- @is_buffer@ flag of libtorch's @Module::register_parameter@ — confirm
-- against the C++ binding.
registerParameter
  :: ForeignPtr Module
  -> ForeignPtr StdString
  -> ForeignPtr Tensor
  -> CBool
  -> IO ()
registerParameter = _cast4 Unmanaged.registerParameter
-- | Register the third argument as a submodule of the first, under the
-- given name, via 'Unmanaged.registerModule'.
registerModule
  :: ForeignPtr Module
  -> ForeignPtr StdString
  -> ForeignPtr Module
  -> IO ()
registerModule = _cast3 Unmanaged.registerModule
-- | Forward a boolean flag to 'Unmanaged.train' for this module.
--
-- NOTE(review): presumably this switches training mode on or off, like
-- libtorch's @Module::train(bool)@ — confirm.
train :: ForeignPtr Module -> CBool -> IO ()
train = _cast2 Unmanaged.train
-- | Run the named method of the module with a @c10::List@ of 'IValue'
-- arguments, via 'Unmanaged.runMethod'.
--
-- Unlike most wrappers in this module, the result is a raw 'Ptr', not a
-- 'ForeignPtr'.  NOTE(review): the result cast is presumably the identity,
-- so no finalizer is attached here — confirm who owns and frees the
-- returned 'IValue'.
runMethod
  :: ForeignPtr Module
  -> ForeignPtr StdString
  -> ForeignPtr (C10List IValue)
  -> IO (Ptr IValue)
runMethod = _cast3 Unmanaged.runMethod
-- | Run the named method of the module with a single 'IValue' argument, via
-- 'Unmanaged.runMethod1'.
--
-- Like 'runMethod', this returns a raw 'Ptr' rather than a 'ForeignPtr'.
-- NOTE(review): confirm ownership of the returned 'IValue'.
runMethod1
  :: ForeignPtr Module
  -> ForeignPtr StdString
  -> ForeignPtr IValue
  -> IO (Ptr IValue)
runMethod1 = _cast3 Unmanaged.runMethod1
-- | Fetch the module's parameters as a 'TensorList', via
-- 'Unmanaged.getParameters'; the result is wrapped in a 'ForeignPtr'.
getParameters :: ForeignPtr Module -> IO (ForeignPtr TensorList)
getParameters = _cast1 Unmanaged.getParameters
-- | Set the module's parameters from a 'TensorList', via
-- 'Unmanaged.setParameters'.
setParameters :: ForeignPtr Module -> ForeignPtr TensorList -> IO ()
setParameters = _cast2 Unmanaged.setParameters
-- | List the module's named parameters as @(name, tensor)@ pairs.
--
-- 'Unmanaged.getNamedParameters' returns raw pointer pairs.  Each element is
-- converted to a 'ForeignPtr' with 'uncast', in order: name first, then
-- tensor.  The whole conversion runs inside 'withForeignPtr', so @obj@ stays
-- alive until every pair has been wrapped.
getNamedParameters
  :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr Tensor)]
getNamedParameters obj =
  withForeignPtr obj $ \obj' -> do
    pairs <- Unmanaged.getNamedParameters obj'
    forM pairs $ \(name, tensor) -> do
      name' <- uncast name return
      tensor' <- uncast tensor return
      return (name', tensor')
-- | List the module's named buffers as @(name, tensor)@ pairs.
--
-- 'Unmanaged.getNamedBuffers' returns raw pointer pairs.  Each element is
-- converted to a 'ForeignPtr' with 'uncast', in order: name first, then
-- tensor.  The whole conversion runs inside 'withForeignPtr', so @obj@ stays
-- alive until every pair has been wrapped.
getNamedBuffers
  :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr Tensor)]
getNamedBuffers obj =
  withForeignPtr obj $ \obj' -> do
    pairs <- Unmanaged.getNamedBuffers obj'
    forM pairs $ \(name, tensor) -> do
      name' <- uncast name return
      tensor' <- uncast tensor return
      return (name', tensor')
-- | List the module's named attributes as @(name, value)@ pairs.
--
-- 'Unmanaged.getNamedAttributes' returns raw pointer pairs.  Each element is
-- converted to a 'ForeignPtr' with 'uncast', in order: name first, then
-- 'IValue'.  The whole conversion runs inside 'withForeignPtr', so @obj@
-- stays alive until every pair has been wrapped.
getNamedAttributes
  :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr IValue)]
getNamedAttributes obj =
  withForeignPtr obj $ \obj' -> do
    pairs <- Unmanaged.getNamedAttributes obj'
    forM pairs $ \(name, value) -> do
      name' <- uncast name return
      value' <- uncast value return
      return (name', value')
-- | List the named modules reported by 'Unmanaged.getNamedModules' as
-- @(name, module)@ pairs.
--
-- Each raw pointer is converted to a 'ForeignPtr' with 'uncast', in order:
-- name first, then module.  The whole conversion runs inside
-- 'withForeignPtr', so @obj@ stays alive until every pair has been wrapped.
getNamedModules
  :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr Module)]
getNamedModules obj =
  withForeignPtr obj $ \obj' -> do
    pairs <- Unmanaged.getNamedModules obj'
    forM pairs $ \(name, sub) -> do
      name' <- uncast name return
      sub' <- uncast sub return
      return (name', sub')
-- | List the named children reported by 'Unmanaged.getNamedChildren' as
-- @(name, module)@ pairs.
--
-- Each raw pointer is converted to a 'ForeignPtr' with 'uncast', in order:
-- name first, then module.  The whole conversion runs inside
-- 'withForeignPtr', so @obj@ stays alive until every pair has been wrapped.
getNamedChildren
  :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr Module)]
getNamedChildren obj =
  withForeignPtr obj $ \obj' -> do
    pairs <- Unmanaged.getNamedChildren obj'
    forM pairs $ \(name, child) -> do
      name' <- uncast name return
      child' <- uncast child return
      return (name', child')
-- | Move the module to a device, via 'Unmanaged.toDevice'.
--
-- NOTE(review): presumably the 'DeviceType' selects the device kind and the
-- 'Int16' is the device index — confirm against the C++ binding.
toDevice :: ForeignPtr Module -> DeviceType -> Int16 -> IO ()
toDevice = _cast3 Unmanaged.toDevice
-- | Clone the module via 'Unmanaged.clone' and return the result as a new
-- 'ForeignPtr'.
clone :: ForeignPtr Module -> IO (ForeignPtr Module)
clone = _cast1 Unmanaged.clone
-- | Pass a source string to 'Unmanaged.define' for this module.
--
-- NOTE(review): presumably this mirrors libtorch's @Module::define@, which
-- compiles TorchScript source into methods on the module — confirm.
define :: ForeignPtr Module -> ForeignPtr StdString -> IO ()
define = _cast2 Unmanaged.define
-- | Trace a Haskell tensor-list function into a module via 'Unmanaged.trace'.
--
-- @func@ is handed to the C++ side as a raw-pointer callback (see @trans@).
-- The callback returns its result as a raw pointer from
-- 'unsafeForeignPtrToPtr', which by itself does not keep the 'ForeignPtr'
-- alive.  To stop the result being finalized while C++ may still use it:
--
--   * the callback stores the 'ForeignPtr' in an 'MVar';
--   * once the unmanaged call returns, it is taken out and passed to
--     'touchForeignPtr'.
--
-- NOTE(review): this assumes the callback runs exactly once.  If it never
-- runs, 'takeMVar' blocks on an empty 'MVar'.  If it runs twice, the second
-- 'putMVar' blocks on a full one.
trace
  :: String
  -- ^ module name (marshalled to a @CString@)
  -> String
  -- ^ function name (marshalled to a @CString@)
  -> (ForeignPtr TensorList -> IO (ForeignPtr TensorList))
  -- ^ function to trace
  -> ForeignPtr TensorList
  -- ^ inputs passed to 'Unmanaged.trace' (presumably the example inputs
  --   the tracer feeds to the function — confirm)
  -> IO (ForeignPtr Module)
trace moduleName functionName func inputs = do
  ref <- newEmptyMVar
  ret <-
    cast3
      (\m f inps -> Unmanaged.trace m f (trans ref func) inps)
      moduleName
      functionName
      inputs
  -- Keep the callback's result reachable until after the C++ call returned.
  v <- takeMVar ref
  touchForeignPtr v
  return ret
  where
    -- Adapt a managed function to the raw-pointer callback shape expected by
    -- 'Unmanaged.trace', stashing the result in @slot@ to keep it alive.
    -- NOTE(review): 'fromPtr' wraps the incoming pointer in a 'ForeignPtr';
    -- this assumes the callback takes ownership of it — confirm.
    trans
      :: MVar (ForeignPtr TensorList)
      -> (ForeignPtr TensorList -> IO (ForeignPtr TensorList))
      -> Ptr TensorList
      -> IO (Ptr TensorList)
    trans slot f rawInputs = do
      managedInputs <- fromPtr rawInputs
      out <- f managedInputs
      putMVar slot out
      return (unsafeForeignPtrToPtr out)
-- | Trace a Haskell tensor-list function into a JIT graph via
-- 'Unmanaged.traceAsGraph'.
--
-- This uses the same lifetime scheme as a module trace.  The callback's
-- result goes back to C++ as a raw pointer from 'unsafeForeignPtrToPtr':
--
--   * the 'ForeignPtr' is stashed in an 'MVar' by the callback;
--   * after the unmanaged call returns, it is taken out and passed to
--     'touchForeignPtr', so it cannot be finalized during tracing.
--
-- NOTE(review): this assumes the callback runs exactly once.  If it never
-- runs, 'takeMVar' blocks on an empty 'MVar'.  If it runs twice, the second
-- 'putMVar' blocks on a full one.
traceAsGraph
  :: (ForeignPtr TensorList -> IO (ForeignPtr TensorList))
  -- ^ function to trace
  -> ForeignPtr TensorList
  -- ^ inputs passed to 'Unmanaged.traceAsGraph'
  -> IO (ForeignPtr (SharedPtr JitGraph))
traceAsGraph func inputs = do
  ref <- newEmptyMVar
  ret <-
    cast1
      (\inps -> Unmanaged.traceAsGraph (trans ref func) inps)
      inputs
  -- Keep the callback's result reachable until after the C++ call returned.
  v <- takeMVar ref
  touchForeignPtr v
  return ret
  where
    -- Adapt a managed function to the raw-pointer callback shape expected by
    -- 'Unmanaged.traceAsGraph', stashing the result in @slot@ to keep it
    -- alive.
    -- NOTE(review): 'fromPtr' wraps the incoming pointer in a 'ForeignPtr';
    -- this assumes the callback takes ownership of it — confirm.
    trans
      :: MVar (ForeignPtr TensorList)
      -> (ForeignPtr TensorList -> IO (ForeignPtr TensorList))
      -> Ptr TensorList
      -> IO (Ptr TensorList)
    trans slot f rawInputs = do
      managedInputs <- fromPtr rawInputs
      out <- f managedInputs
      putMVar slot out
      return (unsafeForeignPtrToPtr out)
-- | Render a JIT graph to a C++ @std::string@ via 'Unmanaged.printGraph'.
printGraph :: ForeignPtr (SharedPtr JitGraph) -> IO (ForeignPtr StdString)
printGraph = _cast1 Unmanaged.printGraph
-- | Render a JIT graph to a C++ @std::string@ via 'Unmanaged.printOnnx'
-- (presumably in ONNX form — confirm against the C++ binding).
printOnnx :: ForeignPtr (SharedPtr JitGraph) -> IO (ForeignPtr StdString)
printOnnx = _cast1 Unmanaged.printOnnx
-- | Dump the module to a C++ @std::string@ via 'Unmanaged.dumpToStr'.
--
-- The three 'CBool' flags are forwarded unchanged.  NOTE(review): presumably
-- these are libtorch's @print_method_bodies@, @print_attr_values@ and
-- @print_param_values@ of @Module::dump_to_str@ — confirm.
dumpToStr
  :: ForeignPtr Module
  -> CBool
  -> CBool
  -> CBool
  -> IO (ForeignPtr StdString)
dumpToStr = _cast4 Unmanaged.dumpToStr