{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Script where
import Control.Exception.Safe (throwIO)
import Control.Monad (forM, forM_, replicateM)
import Data.Int (Int16, Int64)
import Data.List (intercalate)
import Data.Proxy
import Data.Reflection
import Data.Word (Word8)
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import Numeric
import System.IO.Unsafe
import Torch.Autograd
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppObject (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import Torch.Internal.Managed.Type.IValue
import qualified Torch.Internal.Managed.Type.Module as LibTorch
import qualified Torch.Internal.Managed.Type.StdArray as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import Torch.Internal.Type (TensorList)
import qualified Torch.Internal.Type as ATen
import Torch.Internal.Unmanaged.Type.C10Dict
import Torch.Internal.Unmanaged.Type.IValue (IValueLike (..))
import qualified Torch.Internal.Unmanaged.Type.Module as Unmanaged
import Torch.NN
import Torch.Tensor (Tensor (..), toDevice)
import Torch.TensorOptions
-- | Handle to a TorchScript module, as returned by e.g. 'loadScript' or
-- 'toScriptModule'. The constructor is prefixed @Unsafe@ because it exposes
-- the raw foreign pointer.
newtype ScriptModule
  = UnsafeScriptModule (ForeignPtr ATen.Module)

-- | Handle to a module in the form taken by the mutating operations of this
-- module ('registerParameter', 'define', 'setRuntimeMode', …). It wraps the
-- same underlying 'ATen.Module' pointer type as 'ScriptModule'.
newtype RawModule
  = UnsafeRawModule (ForeignPtr ATen.Module)
-- | Renders a 'ScriptModule' with 'dumpToStr'' (which sets all three dump
-- flags). 'unsafePerformIO' keeps 'show' pure. This assumes the dump has no
-- observable side effects — NOTE(review): confirm.
instance Show ScriptModule where
  show = unsafePerformIO . dumpToStr'
-- | Managed pointer to a C++ IValue, as exchanged with the FFI layer.
type RawIValue = ForeignPtr ATen.IValue

-- | Opaque handle wrapping a 'ForeignPtr' to an @ATen.C10Ptr ATen.Blob@.
newtype Blob
  = UnsafeBlob (ForeignPtr (ATen.C10Ptr ATen.Blob))

-- | Opaque handle wrapping a 'ForeignPtr' to an @ATen.C10Ptr ATen.IVObject@.
newtype Object
  = UnsafeObject (ForeignPtr (ATen.C10Ptr ATen.IVObject))

-- | Opaque handle wrapping a 'ForeignPtr' to an @ATen.C10Ptr ATen.IVFuture@.
newtype Future
  = UnsafeFuture (ForeignPtr (ATen.C10Ptr ATen.IVFuture))

-- | Opaque handle wrapping a 'ForeignPtr' to an @ATen.C10Ptr ATen.Capsule@.
newtype Capsule
  = UnsafeCapsule (ForeignPtr (ATen.C10Ptr ATen.Capsule))

-- | Handle to a TorchScript graph, wrapping a 'ForeignPtr' to an
-- @ATen.SharedPtr ATen.JitGraph@ (see 'traceAsGraph').
newtype Graph
  = UnsafeGraph (ForeignPtr (ATen.SharedPtr ATen.JitGraph))
-- | Pure Haskell snapshot of a TorchScript graph (the result type of
-- 'graphToJitGraph'): its input values, output values and nodes.
data JitGraph = JitGraph
  { graphInputs :: [JitValue]
  , graphOutputs :: [JitValue]
  , graphNodes :: [JitNode]
  }
  deriving (Eq, Show)
-- | One node of a 'JitGraph': the values it consumes and produces, plus its
-- kind as a string.
data JitNode = JitNode
  { nodeInputs :: [JitValue]
  , nodeOutputs :: [JitValue]
  , nodeKind :: String
  }
  deriving (Eq, Show)
-- | A value in a 'JitGraph', identified by an integer id and carrying its
-- type rendered as a string.
data JitValue = JitValue
  { valueId :: Int
  , valueType :: String
  }
  deriving (Eq, Show)
-- Opaque handles have no inspectable Haskell-side contents, so each one
-- renders as a fixed name, whatever pointer it holds.
instance Show Blob where
  show = const "Blob"

instance Show Future where
  show = const "Future"

instance Show Object where
  show = const "Object"

instance Show Capsule where
  show = const "Capsule"
-- | Haskell mirror of a TorchScript value, used as the argument/result type
-- of 'forward', 'runMethod' and 'runMethod1'.
--
-- The constructors without fields ('IVBlob', 'IVFuture', 'IVDevice',
-- 'IVObject', 'IVUninitialized', 'IVCapsule') carry no payload here.
-- Presumably they only record which kind of value was present — TODO confirm.
data IValue
  = -- scalars / empty
    IVNone
  | IVTensor Tensor
  | IVDouble Double
  | IVInt Int64
  | IVBool Bool
  | -- containers
    IVTuple [IValue]
  | IVIntList [Int64]
  | IVDoubleList [Double]
  | IVBoolList [Bool]
  | IVString String
  | IVTensorList [Tensor]
  | IVBlob
  | IVGenericList [IValue]
  | IVGenericDict [(IValue, IValue)]
  | -- payload-free tags
    IVFuture
  | IVDevice
  | IVObject
  | IVUninitialized
  | IVCapsule
  deriving (Show)
-- | Marshalling between 'ScriptModule' and its foreign pointer is just
-- (un)wrapping the newtype.
instance Castable ScriptModule (ForeignPtr ATen.Module) where
  cast (UnsafeScriptModule ptr) k = k ptr
  uncast ptr k = k (UnsafeScriptModule ptr)
-- | Marshalling between 'RawModule' and its foreign pointer is just
-- (un)wrapping the newtype.
instance Castable RawModule (ForeignPtr ATen.Module) where
  cast (UnsafeRawModule ptr) k = k ptr
  uncast ptr k = k (UnsafeRawModule ptr)
-- | Marshalling between 'Graph' and its foreign pointer is just
-- (un)wrapping the newtype.
instance Castable Graph (ForeignPtr (ATen.SharedPtr ATen.JitGraph)) where
  cast (UnsafeGraph ptr) k = k ptr
  uncast ptr k = k (UnsafeGraph ptr)
-- | Create a new 'RawModule' through 'LibTorch.newModule'. The string is
-- passed straight through; presumably it becomes the module's (type) name —
-- TODO confirm.
newModule :: String -> IO RawModule
newModule name = cast1 LibTorch.newModule name
-- | Save a 'ScriptModule' to the given path through 'LibTorch.save'.
saveScript :: ScriptModule -> FilePath -> IO ()
saveScript scriptModule path = cast2 LibTorch.save scriptModule path
-- | Same as 'saveScript', but for a 'RawModule'.
saveScript' :: RawModule -> FilePath -> IO ()
saveScript' rawModule path = cast2 LibTorch.save rawModule path
-- | Controls whether 'loadScript' / 'updateParameters' pass parameters
-- through 'makeIndependent' ('WithRequiredGrad') or use them as-is
-- ('WithoutRequiredGrad').
data LoadMode
  = WithoutRequiredGrad
  | WithRequiredGrad
  deriving (Eq, Show)
-- | Load a TorchScript module from @file@ through 'LibTorch.load'.
--
-- With 'WithRequiredGrad', every parameter of the loaded module is passed
-- through 'makeIndependent' and written back with 'setParameters' before the
-- handle is returned. Presumably this makes them gradient-tracking leaves —
-- confirm against "Torch.Autograd".
loadScript :: LoadMode -> FilePath -> IO ScriptModule
loadScript mode file = case mode of
  WithoutRequiredGrad -> cast1 LibTorch.load file
  WithRequiredGrad -> do
    raw@(UnsafeRawModule ptr) <- cast1 LibTorch.load file
    independents <- getParametersIO raw >>= traverse makeIndependent
    setParameters raw (toDependent <$> independents)
    -- Re-wrap the same underlying pointer as a ScriptModule (no clone).
    pure (UnsafeScriptModule ptr)
-- | Load a module from @file@ through 'LibTorch.load' as a 'RawModule',
-- leaving its parameters untouched.
loadScript' :: FilePath -> IO RawModule
loadScript' file = cast1 LibTorch.load file
-- | Run a 'ScriptModule' on a list of 'IValue' arguments via
-- 'LibTorch.forward'.
instance HasForward ScriptModule [IValue] IValue where
  -- Pure variant: runs 'forwardStoch' under 'unsafePerformIO'.
  forward m inputs = unsafePerformIO (forwardStoch m inputs)

  -- Marshal the IValue arguments to raw IValues, call the C++ forward,
  -- then unmarshal the raw result.
  forwardStoch m inputs = cast2 runForward m inputs
    where
      runForward :: ScriptModule -> [RawIValue] -> IO RawIValue
      runForward = cast2 LibTorch.forward
-- | Register a tensor on a 'RawModule' under the given name through
-- 'LibTorch.registerParameter'. The meaning of the final 'Bool' is defined by
-- the C++ side; presumably it is the @is_buffer@ flag — TODO confirm.
registerParameter :: RawModule -> String -> Tensor -> Bool -> IO ()
registerParameter rawModule name tensor flag =
  cast4 LibTorch.registerParameter rawModule name tensor flag
-- | Register @child@ on @parent@ under the given name through
-- 'LibTorch.registerModule'.
registerModule :: RawModule -> String -> RawModule -> IO ()
registerModule parent name child =
  cast3 LibTorch.registerModule parent name child
-- | Parameters of a 'ScriptModule', read through 'LibTorch.getParameters'.
-- The FFI call is hidden behind 'unsafePerformIO' to give a pure interface.
getParameters ::
  ScriptModule ->
  [Tensor]
getParameters scriptModule =
  unsafePerformIO (cast1 LibTorch.getParameters scriptModule)
-- | Parameters of a 'RawModule', read through 'LibTorch.getParameters' in
-- 'IO'.
getParametersIO ::
  RawModule ->
  IO [Tensor]
getParametersIO rawModule = cast1 LibTorch.getParameters rawModule
-- | Overwrite the parameters of a 'RawModule' in place through
-- 'LibTorch.setParameters'.
setParameters :: RawModule -> [Tensor] -> IO ()
setParameters rawModule tensors = cast2 LibTorch.setParameters rawModule tensors
-- | Return a clone of @module'@ whose parameters are replaced by @inputs@.
--
-- The module is first cloned with 'LibTorch.clone', and the new parameters
-- are written into the clone, never into @module'@ itself.
--
-- * 'WithRequiredGrad': each tensor is passed through 'makeIndependent'
--   before being installed. Presumably this makes the new parameters track
--   gradients — confirm against "Torch.Autograd".
-- * 'WithoutRequiredGrad': the tensors are installed as given.
--
-- Previously the 'WithoutRequiredGrad' branch returned the bare clone and
-- silently ignored @inputs@.
--
-- The work runs under 'unsafePerformIO' to expose a pure interface.
updateParameters :: LoadMode -> ScriptModule -> [Tensor] -> ScriptModule
updateParameters mode module' inputs = unsafePerformIO $ do
  r <- cast1 LibTorch.clone module'
  params <- case mode of
    WithoutRequiredGrad -> return inputs
    WithRequiredGrad -> map toDependent <$> forM inputs makeIndependent
  setParameters' r params
  return r
  where
    setParameters' :: ScriptModule -> [Tensor] -> IO ()
    setParameters' = cast2 LibTorch.setParameters
-- | Named parameters of a 'ScriptModule' as @(name, tensor)@ pairs, in the
-- order returned by 'LibTorch.getNamedParameters'. The FFI query runs under
-- 'unsafePerformIO' to give a pure interface.
getNamedParameters ::
  ScriptModule ->
  [(String, Tensor)]
getNamedParameters (UnsafeScriptModule m) = unsafePerformIO $ do
  entries <- LibTorch.getNamedParameters m
  forM entries $ \(rawName, rawTensor) -> do
    name <- uncast rawName return
    tensor <- uncast rawTensor return
    return (name, tensor)
-- | Named buffers of a 'ScriptModule' as @(name, tensor)@ pairs, in the
-- order returned by 'LibTorch.getNamedBuffers'. The FFI query runs under
-- 'unsafePerformIO' to give a pure interface.
getNamedBuffers ::
  ScriptModule ->
  [(String, Tensor)]
getNamedBuffers (UnsafeScriptModule m) = unsafePerformIO $ do
  entries <- LibTorch.getNamedBuffers m
  forM entries $ \(rawName, rawTensor) -> do
    name <- uncast rawName return
    tensor <- uncast rawTensor return
    return (name, tensor)
-- | Named attributes of a 'ScriptModule' as @(name, value)@ pairs, in the
-- order returned by 'LibTorch.getNamedAttributes'. Each raw IValue is
-- unmarshalled to an 'IValue'. The FFI query runs under 'unsafePerformIO'.
getNamedAttributes ::
  ScriptModule ->
  [(String, IValue)]
getNamedAttributes (UnsafeScriptModule m) = unsafePerformIO $ do
  entries <- LibTorch.getNamedAttributes m
  forM entries $ \(rawName, rawValue) -> do
    name <- uncast rawName return
    value <- uncast rawValue return
    return (name, value)
-- | Named submodules of a 'ScriptModule' as @(name, module)@ pairs, in the
-- order returned by 'LibTorch.getNamedModules'. The FFI query runs under
-- 'unsafePerformIO' to give a pure interface.
getNamedModules ::
  ScriptModule ->
  [(String, ScriptModule)]
getNamedModules (UnsafeScriptModule m) = unsafePerformIO $ do
  entries <- LibTorch.getNamedModules m
  forM entries $ \(rawName, rawModule) -> do
    name <- uncast rawName return
    sub <- uncast rawModule return
    return (name, sub)
-- | Named children of a 'ScriptModule' as @(name, module)@ pairs, in the
-- order returned by 'LibTorch.getNamedChildren'. The FFI query runs under
-- 'unsafePerformIO' to give a pure interface.
getNamedChildren ::
  ScriptModule ->
  [(String, ScriptModule)]
getNamedChildren (UnsafeScriptModule m) = unsafePerformIO $ do
  entries <- LibTorch.getNamedChildren m
  forM entries $ \(rawName, rawModule) -> do
    name <- uncast rawName return
    child <- uncast rawModule return
    return (name, child)
-- | Clone a 'RawModule' with 'cloneRawModule' and re-wrap the clone as a
-- 'ScriptModule'. The input module is not re-wrapped itself.
toScriptModule :: RawModule -> IO ScriptModule
toScriptModule rawModule = rewrap <$> cloneRawModule rawModule
  where
    rewrap (UnsafeRawModule ptr) = UnsafeScriptModule ptr
-- | Clone a 'ScriptModule' with 'LibTorch.clone' and return the clone as a
-- 'RawModule'. Both newtypes marshal to the same foreign pointer, so the
-- clone is unmarshalled directly as a 'RawModule'.
toRawModule :: ScriptModule -> IO RawModule
toRawModule scriptModule = cast1 LibTorch.clone scriptModule
-- | Clone a 'RawModule' through 'LibTorch.clone'.
cloneRawModule :: RawModule -> IO RawModule
cloneRawModule rawModule = cast1 LibTorch.clone rawModule
-- | Runtime mode passed to 'setRuntimeMode'.
data RuntimeMode
  = Eval
  | Train
  deriving (Eq, Show)
-- | Switch a 'RawModule' between 'Train' and 'Eval' via 'LibTorch.train'.
-- The flag passed is 'True' for 'Train' and 'False' for 'Eval'.
setRuntimeMode :: RawModule -> RuntimeMode -> IO ()
setRuntimeMode rmod mode = cast2 LibTorch.train rmod isTraining
  where
    isTraining = case mode of
      Train -> True
      Eval -> False
-- | Add TorchScript source (e.g. a @def …@ method, as 'traceWithParameters'
-- does) to a 'RawModule' through 'LibTorch.define'.
define :: RawModule -> String -> IO ()
define rawModule source = cast2 LibTorch.define rawModule source
-- | Render a 'ScriptModule' as text through 'LibTorch.dumpToStr'. The three
-- 'Bool' flags are forwarded unchanged; their meaning is defined by the C++
-- side (presumably which sections to include — TODO confirm).
dumpToStr ::
  ScriptModule ->
  Bool ->
  Bool ->
  Bool ->
  IO String
dumpToStr scriptModule a b c = cast4 LibTorch.dumpToStr scriptModule a b c
-- | 'dumpToStr' with all three flags set to 'True'.
dumpToStr' :: ScriptModule -> IO String
dumpToStr' scriptModule = dumpToStr scriptModule True True True
-- | Call the method named @func@ on a 'ScriptModule' with a list of
-- arguments, through 'LibTorch.runMethod'. The call runs under
-- 'unsafePerformIO' to give a pure interface.
runMethod ::
  ScriptModule ->
  String ->
  [IValue] ->
  IValue
runMethod module' func args = unsafePerformIO $ cast3 invoke module' func args
  where
    invoke :: ScriptModule -> String -> [RawIValue] -> IO RawIValue
    invoke = cast3 LibTorch.runMethod
-- | Call the method named @func@ on a 'ScriptModule' with a single argument,
-- through 'LibTorch.runMethod1'. The call runs under 'unsafePerformIO' to give
-- a pure interface.
runMethod1 ::
  ScriptModule ->
  String ->
  IValue ->
  IValue
runMethod1 module' func arg = unsafePerformIO $ cast3 invoke module' func arg
  where
    invoke :: ScriptModule -> String -> RawIValue -> IO RawIValue
    invoke = cast3 LibTorch.runMethod1
-- | A 'ScriptModule' exposes its 'getParameters' as independent tensors.
-- Replacing them draws one fresh parameter per existing parameter from the
-- stream and builds an updated clone with 'updateParameters'
-- 'WithRequiredGrad'.
instance Parameterized ScriptModule where
  flattenParameters = map IndependentTensor . getParameters

  _replaceParameters module' = do
    fresh <- replicateM (length (getParameters module')) nextParameter
    pure (updateParameters WithRequiredGrad module' (toDependent <$> fresh))
-- | Trace a Haskell tensor function into a new 'RawModule'.
--
-- @moduleName@, @functionName@ and the input tensors are handed to
-- 'LibTorch.trace'. @func@ is adapted to the FFI tensor-list representation
-- so the C++ side can call back into it.
trace ::
  String ->
  String ->
  ([Tensor] -> IO [Tensor]) ->
  [Tensor] ->
  IO RawModule
trace moduleName functionName func inputs =
  cast3 runTracer moduleName functionName inputs
  where
    runTracer m f inps = LibTorch.trace m f (bridge func) inps

    -- Unmarshal the tensor list, run the Haskell callback, marshal back.
    bridge :: ([Tensor] -> IO [Tensor]) -> ForeignPtr TensorList -> IO (ForeignPtr TensorList)
    bridge fn rawIn =
      uncast rawIn $ \ts -> do
        out <- fn ts
        cast out return
-- | Trace @func@ with the parameters of @parameterized_parameters@ lifted to
-- explicit inputs, producing a 'RawModule' that owns those parameters.
--
-- 1. 'trace' records a method named @forwardWithParameters@. Its arguments
--    are the flattened parameters followed by @inputs@.
-- 2. Each parameter is registered on the module as @p0@, @p1@, … (passing
--    'False' as the last 'registerParameter' argument).
-- 3. A TorchScript @forward(self, i0, i1, …)@ method is defined. It calls
--    @self.forwardWithParameters(self.p0, …, i0, …)@.
--
-- Each generated argument list is built with a single 'intercalate', so the
-- TorchScript source stays syntactically valid when there are no parameters
-- and/or no inputs. Previously a parameter-free model produced
-- @forwardWithParameters(, i0 )@.
traceWithParameters ::
  Parameterized f =>
  String ->
  (f -> [Tensor] -> IO [Tensor]) ->
  f ->
  [Tensor] ->
  IO RawModule
traceWithParameters moduleName func parameterized_parameters inputs = do
  let parameters = map toDependent (flattenParameters parameterized_parameters)
      -- Rebuild an @f@ from the parameter slice the tracer passes back.
      fromParams params =
        replaceParameters parameterized_parameters (map IndependentTensor params)
      plen = length parameters
      ilen = length inputs
      paramNames = ["p" ++ show i | i <- [0 .. plen - 1]]
      argNames = ["i" ++ show i | i <- [0 .. ilen - 1]]
  r <-
    trace
      moduleName
      "forwardWithParameters"
      ( \parametersAndInputs ->
          -- The first plen tensors are the parameters; the rest are inputs.
          let (ps, ins) = splitAt plen parametersAndInputs
           in func (fromParams ps) ins
      )
      (parameters ++ inputs)
  forM_ (zip paramNames parameters) $ \(name, p) ->
    registerParameter r name p False
  let signature = intercalate ", " ("self" : argNames)
      callArgs = intercalate ", " (map ("self." ++) paramNames ++ argNames)
  define r $
    "def forward(" ++ signature ++ "):\n"
      ++ " return self.forwardWithParameters("
      ++ callArgs
      ++ " )\n"
  return r
-- | Trace a Haskell tensor function on the given inputs and return the
-- resulting 'Graph' (via 'LibTorch.traceAsGraph'). @func@ is adapted to the
-- FFI tensor-list representation so the C++ side can call back into it.
traceAsGraph ::
  ([Tensor] -> IO [Tensor]) ->
  [Tensor] ->
  IO Graph
traceAsGraph func inputs = cast1 (LibTorch.traceAsGraph (bridge func)) inputs
  where
    -- Unmarshal the tensor list, run the Haskell callback, marshal back.
    bridge :: ([Tensor] -> IO [Tensor]) -> ForeignPtr TensorList -> IO (ForeignPtr TensorList)
    bridge fn rawIn =
      uncast rawIn $ \ts -> do
        out <- fn ts
        cast out return
-- | Render a 'Graph' as a 'String' using libtorch's 'LibTorch.printGraph'.
printGraph :: Graph -> IO String
printGraph graph = cast1 LibTorch.printGraph graph
-- | Render a 'Graph' as a 'String' using libtorch's 'LibTorch.printOnnx'
-- printer.
printOnnx :: Graph -> IO String
printOnnx graph = cast1 LibTorch.printOnnx graph
-- | Convert an opaque 'Graph' into a plain Haskell 'JitGraph' description.
--
-- The graph's inputs, outputs and nodes are read through the unmanaged
-- binding while the underlying pointer is held by 'withForeignPtr'. Each value
-- is recorded by its id and type string, and each node by its kind together
-- with its input and output values.
graphToJitGraph :: Graph -> IO JitGraph
graphToJitGraph (UnsafeGraph graph) =
  withForeignPtr graph $ \sharedGraph ->
    Unmanaged.withJitGraph sharedGraph $ \rawGraph -> do
      graphInputs <- Unmanaged.graphInputs rawGraph >>= toJitValue
      graphOutputs <- Unmanaged.graphOutputs rawGraph >>= toJitValue
      graphNodes <- Unmanaged.graphNodes rawGraph >>= toJitNode
      pure JitGraph {..}
  where
    -- Copy the id and the type name of every raw value into a 'JitValue'.
    toJitValue rawValues =
      forM rawValues $ \rawValue -> do
        valueId <- cast1 Unmanaged.valueId rawValue
        valueType <- cast0 (cast1 Unmanaged.valueType rawValue :: IO (ForeignPtr ATen.StdString))
        pure JitValue {..}
    -- Copy the inputs, outputs and kind of every raw node into a 'JitNode'.
    toJitNode rawNodes =
      forM rawNodes $ \rawNode -> do
        nodeInputs <- Unmanaged.nodeInputs rawNode >>= toJitValue
        nodeOutputs <- Unmanaged.nodeOutputs rawNode >>= toJitValue
        nodeKind <- cast0 (cast1 Unmanaged.nodeKind rawNode :: IO (ForeignPtr ATen.StdString))
        pure JitNode {..}
-- | Element-wise marshalling between lists of Haskell 'IValue's and lists of
-- raw libtorch IValues.
--
-- Each element goes through the single-element 'Castable' instance, and the
-- fully converted list is then passed to the continuation.
instance Castable [IValue] [RawIValue] where
  cast values k = do
    raws <- mapM (\value -> cast value return) values
    k raws
  uncast raws k = do
    values <- mapM (\raw -> uncast raw return) raws
    k values
-- | Marshal a Haskell 'IValue' into a raw libtorch IValue ('cast') and back
-- ('uncast').
--
-- 'cast' handles only the constructors matched in 'toRaw' below. Any other
-- constructor raises an 'IOError' whose message is "Unsupported data-type:"
-- followed by the 'show'n value.
--
-- 'uncast' probes the raw value with the @iValue_is*@ predicates in a fixed
-- order and decodes with the first one that reports 1. If none match, it
-- raises an 'IOError' with the message "Unsupported IValue". Blob, future,
-- device, object and capsule values decode to payload-less constructors.
instance Castable IValue RawIValue where
  cast value k = toRaw value >>= k
    where
      -- Build the raw IValue for one Haskell value.
      toRaw :: IValue -> IO RawIValue
      toRaw IVNone = newIValue
      toRaw (IVTensor (Unsafe tensorPtr)) = toIValue tensorPtr
      toRaw (IVDouble d) = toIValue d
      toRaw (IVInt n) = toIValue n
      toRaw (IVBool b) = toIValue b
      toRaw (IVTuple items) = do
        -- Convert the elements first, then pack them into a C10 tuple.
        rawItems <- cast items return :: IO [RawIValue]
        tuplePtr <- cast rawItems return :: IO (ForeignPtr (ATen.C10Ptr ATen.IVTuple))
        toIValue tuplePtr
      toRaw (IVIntList ints) = do
        listPtr <- cast ints return :: IO (ForeignPtr (ATen.C10List Int64))
        toIValue listPtr
      toRaw (IVDoubleList ds) = do
        -- The C10 list is typed on CDouble, so convert each element first.
        cds <- traverse (\d -> cast d return) ds :: IO [CDouble]
        listPtr <- cast cds return :: IO (ForeignPtr (ATen.C10List CDouble))
        toIValue listPtr
      toRaw (IVBoolList bs) = do
        -- The C10 list is typed on CBool, so convert each element first.
        cbs <- traverse (\b -> cast b return) bs :: IO [CBool]
        listPtr <- cast cbs return :: IO (ForeignPtr (ATen.C10List CBool))
        toIValue listPtr
      toRaw (IVString s) = do
        strPtr <- cast s return :: IO (ForeignPtr ATen.StdString)
        toIValue strPtr
      toRaw (IVTensorList ts) = do
        listPtr <- cast ts return :: IO (ForeignPtr (ATen.C10List ATen.Tensor))
        toIValue listPtr
      toRaw (IVGenericList items) = do
        rawItems <- cast items return :: IO [RawIValue]
        listPtr <- cast rawItems return :: IO (ForeignPtr (ATen.C10List ATen.IValue))
        toIValue listPtr
      toRaw (IVGenericDict entries) = do
        -- Convert every key before any value, then zip them into pairs.
        rawKeys <- cast (map fst entries) return :: IO [RawIValue]
        rawVals <- cast (map snd entries) return :: IO [RawIValue]
        dictPtr <-
          cast (zip rawKeys rawVals) return ::
            IO (ForeignPtr (ATen.C10Dict '(ATen.IValue, ATen.IValue)))
        toIValue dictPtr
      toRaw other = throwIO $ userError $ "Unsupported data-type:" ++ show other

  uncast raw k = foldr tryKind (throwIO $ userError "Unsupported IValue") alternatives
    where
      -- Run the action paired with the predicate when that predicate yields
      -- 1. Otherwise continue with the remaining alternatives.
      tryKind (isKind, action) orElse = do
        flag <- isKind
        if flag == 1 then action else orElse
      -- Probed strictly in this order, and the first match wins.
      -- NOTE(review): the typed-list checks precede 'iValue_isList',
      -- presumably because the generic check may also accept typed lists.
      -- Confirm against c10::IValue before reordering.
      alternatives =
        [ (iValue_isNone raw, k IVNone),
          (iValue_isTensor raw, fromIValue raw >>= \t -> k (IVTensor (Unsafe t))),
          (iValue_isDouble raw, fromIValue raw >>= \d -> k (IVDouble d)),
          (iValue_isInt raw, fromIValue raw >>= \n -> k (IVInt n)),
          (iValue_isBool raw, fromIValue raw >>= \b -> k (IVBool b)),
          ( iValue_isString raw,
            do
              strPtr <- fromIValue raw :: IO (ForeignPtr ATen.StdString)
              str <- uncast strPtr return :: IO String
              k (IVString str)
          ),
          ( iValue_isTensorList raw,
            do
              listPtr <- fromIValue raw :: IO (ForeignPtr (ATen.C10List ATen.Tensor))
              ts <- uncast listPtr return :: IO [Tensor]
              k (IVTensorList ts)
          ),
          ( iValue_isDoubleList raw,
            do
              listPtr <- fromIValue raw :: IO (ForeignPtr (ATen.C10List CDouble))
              cds <- uncast listPtr return :: IO [CDouble]
              ds <- traverse (\cd -> uncast cd return) cds :: IO [Double]
              k (IVDoubleList ds)
          ),
          ( iValue_isIntList raw,
            do
              listPtr <- fromIValue raw :: IO (ForeignPtr (ATen.C10List Int64))
              ns <- uncast listPtr return :: IO [Int64]
              k (IVIntList ns)
          ),
          ( iValue_isBoolList raw,
            do
              listPtr <- fromIValue raw :: IO (ForeignPtr (ATen.C10List CBool))
              cbs <- uncast listPtr return :: IO [CBool]
              bs <- traverse (\cb -> uncast cb return) cbs :: IO [Bool]
              k (IVBoolList bs)
          ),
          ( iValue_isTuple raw,
            do
              tuplePtr <- fromIValue raw :: IO (ForeignPtr (ATen.C10Ptr ATen.IVTuple))
              rawItems <- uncast tuplePtr return :: IO [RawIValue]
              items <- uncast rawItems return :: IO [IValue]
              k (IVTuple items)
          ),
          ( iValue_isList raw,
            do
              listPtr <- fromIValue raw :: IO (ForeignPtr (ATen.C10List ATen.IValue))
              rawItems <- uncast listPtr return :: IO [RawIValue]
              items <- uncast rawItems return :: IO [IValue]
              k (IVGenericList items)
          ),
          ( iValue_isGenericDict raw,
            do
              dictPtr <- fromIValue raw :: IO (ForeignPtr (ATen.C10Dict '(ATen.IValue, ATen.IValue)))
              rawPairs <- uncast dictPtr return :: IO [(RawIValue, RawIValue)]
              -- Key is decoded before value within each pair, as in the
              -- applicative chain below (IO runs left to right).
              entries <-
                traverse
                  (\(rawKey, rawVal) -> (,) <$> uncast rawKey return <*> uncast rawVal return)
                  rawPairs
              k (IVGenericDict entries)
          ),
          (iValue_isBlob raw, k IVBlob),
          (iValue_isFuture raw, k IVFuture),
          (iValue_isDevice raw, k IVDevice),
          (iValue_isObject raw, k IVObject),
          (iValue_isCapsule raw, k IVCapsule)
        ]