{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}



module Torch.Internal.Managed.Type.Module where


import Foreign.C.String
import Foreign.C.Types
import Foreign
import Foreign.ForeignPtr.Unsafe
import Torch.Internal.Type
import Torch.Internal.Class
import Torch.Internal.Cast
import Torch.Internal.Objects
import Control.Monad(forM)
import Control.Concurrent.MVar (MVar(..), newEmptyMVar, putMVar, takeMVar)

import qualified Torch.Internal.Unmanaged.Type.Module as Unmanaged

-- | Create a new 'Module' (managed wrapper around 'Unmanaged.newModule').
--
-- '_cast1' converts the 'ForeignPtr' argument and the raw result through
-- their 'Castable' instances.
--
-- NOTE(review): the 'StdString' argument is presumably the module's name —
-- confirm against the unmanaged binding.
newModule :: ForeignPtr StdString -> IO (ForeignPtr Module)
newModule = _cast1 Unmanaged.newModule

-- | Save the module to the file at the given path
-- (managed wrapper around 'Unmanaged.save').
save :: ForeignPtr Module -> FilePath -> IO ()
save = _cast2 Unmanaged.save

-- | Load a module from the file at the given path
-- (managed wrapper around 'Unmanaged.load').
load :: FilePath -> IO (ForeignPtr Module)
load = _cast1 Unmanaged.load

-- | Call the module's forward entry point with a vector of 'IValue'
-- arguments (managed wrapper around 'Unmanaged.forward').
forward
  :: ForeignPtr Module
  -> ForeignPtr (StdVector IValue)
  -> IO (ForeignPtr IValue)
forward = _cast2 Unmanaged.forward

-- | Register a named tensor on the module
-- (managed wrapper around 'Unmanaged.registerParameter').
--
-- NOTE(review): the 'CBool' is forwarded unchanged. In libtorch's
-- @torch::jit::Module::register_parameter@ the matching argument is
-- @is_buffer@ — confirm against the unmanaged binding.
registerParameter
  :: ForeignPtr Module
  -> ForeignPtr StdString
  -> ForeignPtr Tensor
  -> CBool
  -> IO ()
registerParameter = _cast4 Unmanaged.registerParameter

-- | Register a child module under the given name
-- (managed wrapper around 'Unmanaged.registerModule').
registerModule
  :: ForeignPtr Module
  -> ForeignPtr StdString
  -> ForeignPtr Module
  -> IO ()
registerModule = _cast3 Unmanaged.registerModule

-- | Switch the module's training flag
-- (managed wrapper around 'Unmanaged.train').
--
-- NOTE(review): presumably a true 'CBool' selects training mode and false
-- selects evaluation mode — confirm.
train :: ForeignPtr Module -> CBool -> IO ()
train = _cast2 Unmanaged.train

-- | Run the named method of the module with a list of 'IValue' arguments
-- (managed wrapper around 'Unmanaged.runMethod').
--
-- Unlike most wrappers in this module, the result is a raw 'Ptr', not a
-- 'ForeignPtr', so the caller is responsible for its lifetime.
runMethod
  :: ForeignPtr Module
  -> ForeignPtr StdString
  -> ForeignPtr (C10List IValue)
  -> IO (Ptr IValue)
runMethod = _cast3 Unmanaged.runMethod

-- | Run the named method of the module with a single 'IValue' argument
-- (managed wrapper around 'Unmanaged.runMethod1').
--
-- Like 'runMethod', the result is a raw 'Ptr' whose lifetime the caller
-- must manage.
runMethod1
  :: ForeignPtr Module
  -> ForeignPtr StdString
  -> ForeignPtr IValue
  -> IO (Ptr IValue)
runMethod1 = _cast3 Unmanaged.runMethod1

-- | Fetch the module's parameters as a 'TensorList'
-- (managed wrapper around 'Unmanaged.getParameters').
getParameters :: ForeignPtr Module -> IO (ForeignPtr TensorList)
getParameters = _cast1 Unmanaged.getParameters

-- | Replace the module's parameters with the given 'TensorList'
-- (managed wrapper around 'Unmanaged.setParameters').
setParameters :: ForeignPtr Module -> ForeignPtr TensorList -> IO ()
setParameters = _cast2 Unmanaged.setParameters

-- | List the module's parameters as @(name, tensor)@ pairs.
--
-- 'Unmanaged.getNamedParameters' yields raw pointers. Each half of every
-- pair is converted with 'uncast' (name first, then tensor) while the
-- module pointer is held by 'withForeignPtr'.
getNamedParameters :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr Tensor)]
getNamedParameters obj = withForeignPtr obj $ \obj' -> do
  rawPairs <- Unmanaged.getNamedParameters obj'
  forM rawPairs $ \(name, tensor) ->
    (,) <$> uncast name return <*> uncast tensor return

-- | List the module's buffers as @(name, tensor)@ pairs.
--
-- 'Unmanaged.getNamedBuffers' yields raw pointers. Each half of every pair
-- is converted with 'uncast' (name first, then tensor) while the module
-- pointer is held by 'withForeignPtr'.
getNamedBuffers :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr Tensor)]
getNamedBuffers obj = withForeignPtr obj $ \obj' -> do
  rawPairs <- Unmanaged.getNamedBuffers obj'
  forM rawPairs $ \(name, tensor) ->
    (,) <$> uncast name return <*> uncast tensor return

-- | List the module's attributes as @(name, value)@ pairs.
--
-- 'Unmanaged.getNamedAttributes' yields raw pointers. Each half of every
-- pair is converted with 'uncast' (name first, then value) while the
-- module pointer is held by 'withForeignPtr'.
getNamedAttributes :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr IValue)]
getNamedAttributes obj = withForeignPtr obj $ \obj' -> do
  rawPairs <- Unmanaged.getNamedAttributes obj'
  forM rawPairs $ \(name, value) ->
    (,) <$> uncast name return <*> uncast value return

-- | List submodules as @(name, module)@ pairs, as returned by
-- 'Unmanaged.getNamedModules'.
--
-- NOTE(review): libtorch's @named_modules@ is recursive while
-- @named_children@ lists direct children only. Presumably that split is
-- mirrored by 'getNamedChildren' — confirm.
--
-- Each half of every pair is converted with 'uncast' (name first, then
-- module) while the parent pointer is held by 'withForeignPtr'.
getNamedModules :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr Module)]
getNamedModules obj = withForeignPtr obj $ \obj' -> do
  rawPairs <- Unmanaged.getNamedModules obj'
  forM rawPairs $ \(name, sub) ->
    (,) <$> uncast name return <*> uncast sub return

-- | List child modules as @(name, module)@ pairs, as returned by
-- 'Unmanaged.getNamedChildren'.
--
-- NOTE(review): presumably only direct children, mirroring libtorch's
-- @named_children@ — confirm.
--
-- Each half of every pair is converted with 'uncast' (name first, then
-- module) while the parent pointer is held by 'withForeignPtr'.
getNamedChildren :: ForeignPtr Module -> IO [(ForeignPtr StdString, ForeignPtr Module)]
getNamedChildren obj = withForeignPtr obj $ \obj' -> do
  rawPairs <- Unmanaged.getNamedChildren obj'
  forM rawPairs $ \(name, child) ->
    (,) <$> uncast name return <*> uncast child return

-- | Move the module to a device (managed wrapper around 'Unmanaged.toDevice').
--
-- NOTE(review): the 'Int16' is presumably the device index within the given
-- 'DeviceType' — confirm against the unmanaged binding.
toDevice :: ForeignPtr Module -> DeviceType -> Int16 -> IO ()
toDevice = _cast3 Unmanaged.toDevice

-- | Clone the module, returning a new managed handle
-- (managed wrapper around 'Unmanaged.clone').
clone :: ForeignPtr Module -> IO (ForeignPtr Module)
clone = _cast1 Unmanaged.clone

-- | Define code on the module from the given source string
-- (managed wrapper around 'Unmanaged.define').
--
-- NOTE(review): the string is presumably TorchScript source — confirm.
define :: ForeignPtr Module -> ForeignPtr StdString -> IO ()
define = _cast2 Unmanaged.define


-- | Trace a Haskell function into a new 'Module'.
--
-- @trace moduleName functionName func inputs@ passes the two names,
-- @inputs@ and a callback wrapping @func@ to 'Unmanaged.trace'.
--
-- Lifetime: the callback hands C++ a raw pointer obtained with
-- 'unsafeForeignPtrToPtr'. The 'ForeignPtr' produced by @func@ must
-- therefore not be released before 'Unmanaged.trace' returns. The callback
-- stores that 'ForeignPtr' in an 'MVar'. After the unmanaged call it is
-- taken back and passed to 'touchForeignPtr', which keeps it alive until
-- then.
--
-- NOTE(review): this assumes the callback runs exactly once per call.
-- 'takeMVar' blocks if it never runs, and a second run would block in
-- 'putMVar'.
trace
  :: String
  -> String
  -> (ForeignPtr TensorList -> IO (ForeignPtr TensorList))
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Module)
trace moduleName functionName func inputs = do
  ref <- newEmptyMVar
  ret <- cast3 (\m f inps -> Unmanaged.trace m f (trans ref func) inps) moduleName functionName inputs
  v <- takeMVar ref
  touchForeignPtr v
  return ret
  where
    -- Adapt @func@ to the raw-pointer callback shape expected by C++.
    -- Parameter names differ from the outer ones to avoid shadowing.
    trans
      :: MVar (ForeignPtr TensorList)
      -> (ForeignPtr TensorList -> IO (ForeignPtr TensorList))
      -> Ptr TensorList
      -> IO (Ptr TensorList)
    trans keepAlive userFn rawInputs = do
      managedInputs <- fromPtr rawInputs
      result <- userFn managedInputs
      -- Park the result so the caller can keep it alive past the C++ call.
      putMVar keepAlive result
      return $ unsafeForeignPtrToPtr result

-- | Trace a Haskell function and return the resulting graph
-- (via 'Unmanaged.traceAsGraph').
--
-- Lifetime: the callback hands C++ a raw pointer obtained with
-- 'unsafeForeignPtrToPtr'. The 'ForeignPtr' produced by @func@ must
-- therefore not be released before 'Unmanaged.traceAsGraph' returns. The
-- callback stores that 'ForeignPtr' in an 'MVar'. After the unmanaged call
-- it is taken back and passed to 'touchForeignPtr', which keeps it alive
-- until then.
--
-- NOTE(review): this assumes the callback runs exactly once per call.
-- 'takeMVar' blocks if it never runs, and a second run would block in
-- 'putMVar'.
traceAsGraph
  :: (ForeignPtr TensorList -> IO (ForeignPtr TensorList))
  -> ForeignPtr TensorList
  -> IO (ForeignPtr (SharedPtr JitGraph))
traceAsGraph func inputs = do
  ref <- newEmptyMVar
  ret <- cast1 (\inps -> Unmanaged.traceAsGraph (trans ref func) inps) inputs
  v <- takeMVar ref
  touchForeignPtr v
  return ret
  where
    -- Adapt @func@ to the raw-pointer callback shape expected by C++.
    -- Parameter names differ from the outer ones to avoid shadowing.
    trans
      :: MVar (ForeignPtr TensorList)
      -> (ForeignPtr TensorList -> IO (ForeignPtr TensorList))
      -> Ptr TensorList
      -> IO (Ptr TensorList)
    trans keepAlive userFn rawInputs = do
      managedInputs <- fromPtr rawInputs
      result <- userFn managedInputs
      -- Park the result so the caller can keep it alive past the C++ call.
      putMVar keepAlive result
      return $ unsafeForeignPtrToPtr result

-- | Render a traced graph as a string
-- (managed wrapper around 'Unmanaged.printGraph').
printGraph :: ForeignPtr (SharedPtr JitGraph) -> IO (ForeignPtr StdString)
printGraph = _cast1 Unmanaged.printGraph

-- | Render a traced graph as a string
-- (managed wrapper around 'Unmanaged.printOnnx').
--
-- NOTE(review): presumably an ONNX rendering of the graph — confirm.
printOnnx :: ForeignPtr (SharedPtr JitGraph) -> IO (ForeignPtr StdString)
printOnnx = _cast1 Unmanaged.printOnnx

-- | Dump a textual description of the module
-- (managed wrapper around 'Unmanaged.dumpToStr').
--
-- NOTE(review): the three 'CBool' flags are forwarded unchanged. In
-- libtorch's @dump_to_str@ they are @print_method_bodies@,
-- @print_attr_values@ and @print_param_values@ — confirm.
dumpToStr
  :: ForeignPtr Module
  -> CBool
  -> CBool
  -> CBool
  -> IO (ForeignPtr StdString)
dumpToStr = _cast4 Unmanaged.dumpToStr