{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Script where

import Control.Exception.Safe (throwIO)
import Control.Monad (forM, forM_, replicateM)
import Data.Int (Int16, Int64)
import Data.List (intercalate)
import Data.Proxy
import Data.Reflection
import Data.Word (Word8)
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import Numeric
import System.IO.Unsafe
import Torch.Autograd
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppObject (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import Torch.Internal.Managed.Type.IValue
import qualified Torch.Internal.Managed.Type.Module as LibTorch
import qualified Torch.Internal.Managed.Type.StdArray as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import Torch.Internal.Type (TensorList)
import qualified Torch.Internal.Type as ATen
import Torch.Internal.Unmanaged.Type.C10Dict
import Torch.Internal.Unmanaged.Type.IValue (IValueLike (..))
import qualified Torch.Internal.Unmanaged.Type.Module as Unmanaged
import Torch.NN
import Torch.Tensor (Tensor (..), toDevice)
import Torch.TensorOptions

-- | Handle to a LibTorch module ('ATen.Module') held through a 'ForeignPtr'.
-- In this file it is the type taken by the read-only accessors and by the
-- 'unsafePerformIO'-based pure wrappers (e.g. 'getParameters', 'runMethod').
newtype ScriptModule = UnsafeScriptModule (ForeignPtr ATen.Module)

-- | Handle to the same underlying 'ATen.Module' type as 'ScriptModule'.
-- In this file it is the type taken by the 'IO' operations that presumably
-- modify a module in place (e.g. 'registerParameter', 'setParameters',
-- 'define'); use 'toScriptModule' / 'toRawModule' to convert via a clone.
newtype RawModule = UnsafeRawModule (ForeignPtr ATen.Module)

-- | Renders a module as the textual dump produced by 'dumpToStr''.
instance Show ScriptModule where
  -- 'show' has to be pure, so the IO dump is run through 'unsafePerformIO'.
  -- NOTE(review): this assumes dumping does not modify the module — confirm.
  show obj = unsafePerformIO $ dumpToStr' obj

-- | Alias for a 'ForeignPtr' to a LibTorch 'ATen.IValue'; used in this file as
-- the marshalled form of values passed to and returned from module calls.
type RawIValue = ForeignPtr ATen.IValue

-- | Opaque wrapper around a foreign pointer to a 'ATen.C10Ptr' 'ATen.Blob'.
newtype Blob = UnsafeBlob (ForeignPtr (ATen.C10Ptr ATen.Blob))

-- | Opaque wrapper around a foreign pointer to a 'ATen.C10Ptr' 'ATen.IVObject'.
newtype Object = UnsafeObject (ForeignPtr (ATen.C10Ptr ATen.IVObject))

-- | Opaque wrapper around a foreign pointer to a 'ATen.C10Ptr' 'ATen.IVFuture'.
newtype Future = UnsafeFuture (ForeignPtr (ATen.C10Ptr ATen.IVFuture))

-- | Opaque wrapper around a foreign pointer to a 'ATen.C10Ptr' 'ATen.Capsule'.
newtype Capsule = UnsafeCapsule (ForeignPtr (ATen.C10Ptr ATen.Capsule))

-- | Wrapper around a foreign pointer to a shared 'ATen.JitGraph'.
-- See https://github.com/pytorch/pytorch/wiki/PyTorch-IR
newtype Graph = UnsafeGraph (ForeignPtr (ATen.SharedPtr ATen.JitGraph))

-- | Haskell-side description of a JIT graph: its input values, its output
-- values and its nodes.
data JitGraph = JitGraph
  { -- | Values listed as the graph's inputs.
    graphInputs :: [JitValue],
    -- | Values listed as the graph's outputs.
    graphOutputs :: [JitValue],
    -- | Nodes of the graph.
    graphNodes :: [JitNode]
  }
  deriving (Show, Eq)

-- | Haskell-side description of a single node of a JIT graph.
data JitNode = JitNode
  { -- | Values consumed by the node.
    nodeInputs :: [JitValue],
    -- | Values produced by the node.
    nodeOutputs :: [JitValue],
    -- | Kind of the node, as text.
    nodeKind :: String
  }
  deriving (Show, Eq)

-- | Haskell-side description of a value appearing in a JIT graph.
data JitValue = JitValue
  { -- | Numeric identifier of the value.
    valueId :: Int,
    -- | Type of the value, as text.
    valueType :: String
  }
  deriving (Show, Eq)

-- | The wrapped pointer is opaque, so every 'Blob' shows as the fixed text
-- @"Blob"@.
instance Show Blob where
  show _ = "Blob"

-- | The wrapped pointer is opaque, so every 'Future' shows as the fixed text
-- @"Future"@.
instance Show Future where
  show _ = "Future"

-- | The wrapped pointer is opaque, so every 'Object' shows as the fixed text
-- @"Object"@.
instance Show Object where
  show _ = "Object"

-- | The wrapped pointer is opaque, so every 'Capsule' shows as the fixed text
-- @"Capsule"@.
instance Show Capsule where
  show _ = "Capsule"

-- | Haskell representation of the kinds of value exchanged with a script
-- module. Constructors whose payload is only named in a trailing comment
-- ('IVBlob', 'IVFuture', 'IVDevice', 'IVObject', 'IVCapsule') carry no data.
data IValue
  = IVNone
  | IVTensor Tensor
  | IVDouble Double
  | IVInt Int64
  | IVBool Bool
  | IVTuple [IValue]
  | IVIntList [Int64]
  | IVDoubleList [Double]
  | IVBoolList [Bool]
  | IVString String
  | IVTensorList [Tensor]
  | IVBlob -- Blob
  | IVGenericList [IValue]
  | IVGenericDict [(IValue, IValue)]
  | IVFuture -- Future
  | IVDevice -- Device
  | IVObject -- Object
  | IVUninitialized
  | IVCapsule -- Capsule
  deriving (Show)

-- | Marshals between 'ScriptModule' and its underlying foreign pointer by
-- unwrapping / re-wrapping the 'UnsafeScriptModule' constructor.
instance Castable ScriptModule (ForeignPtr ATen.Module) where
  -- Hand the wrapped pointer to the continuation.
  cast (UnsafeScriptModule obj) f = f obj

  -- Wrap a raw pointer and hand the result to the continuation.
  uncast obj f = f $ UnsafeScriptModule obj

-- | Marshals between 'RawModule' and its underlying foreign pointer by
-- unwrapping / re-wrapping the 'UnsafeRawModule' constructor.
instance Castable RawModule (ForeignPtr ATen.Module) where
  -- Hand the wrapped pointer to the continuation.
  cast (UnsafeRawModule obj) f = f obj

  -- Wrap a raw pointer and hand the result to the continuation.
  uncast obj f = f $ UnsafeRawModule obj

-- | Marshals between 'Graph' and its underlying foreign pointer by
-- unwrapping / re-wrapping the 'UnsafeGraph' constructor.
instance Castable Graph (ForeignPtr (ATen.SharedPtr ATen.JitGraph)) where
  -- Hand the wrapped pointer to the continuation.
  cast (UnsafeGraph obj) f = f obj

  -- Wrap a raw pointer and hand the result to the continuation.
  uncast obj f = f $ UnsafeGraph obj

-- | Creates a new 'RawModule' by calling 'LibTorch.newModule' with the given
-- string (presumably the module name — confirm against the binding).
-- 'cast1' marshals the argument and the result.
newModule :: String -> IO RawModule
newModule = cast1 LibTorch.newModule

-- | Saves a 'ScriptModule' to the given file path via 'LibTorch.save'.
saveScript :: ScriptModule -> FilePath -> IO ()
saveScript = cast2 LibTorch.save

-- | Saves a 'RawModule' to the given file path via 'LibTorch.save'.
-- Same operation as 'saveScript', for the 'RawModule' wrapper.
saveScript' :: RawModule -> FilePath -> IO ()
saveScript' = cast2 LibTorch.save

-- | Selects how 'loadScript' and 'updateParameters' treat module parameters.
data LoadMode
  = -- | Leave the parameters as they are.
    WithoutRequiredGrad
  | -- | Pass every parameter through 'makeIndependent' before installing it.
    WithRequiredGrad
  deriving (Show, Eq)

-- | Load a torchscript file.
--
-- With 'WithoutRequiredGrad' the file is loaded as is. With
-- 'WithRequiredGrad' every parameter of the loaded module is passed through
-- 'makeIndependent' and written back with 'setParameters' before the module
-- is returned.
loadScript :: LoadMode -> FilePath -> IO ScriptModule
loadScript WithoutRequiredGrad file = cast1 LibTorch.load file
loadScript WithRequiredGrad file = do
  -- Load as a 'RawModule' first: 'getParametersIO' and 'setParameters' work on
  -- that type.
  module'@(UnsafeRawModule rmodule) <- cast1 LibTorch.load file
  params <- getParametersIO module'
  paramsWithRequiredGrad <- forM params makeIndependent
  setParameters module' (map toDependent paramsWithRequiredGrad)
  -- Re-wrap the same foreign pointer; no clone is made here.
  return (UnsafeScriptModule rmodule)

-- | Loads a torchscript file as a 'RawModule' via 'LibTorch.load', without
-- touching its parameters.
loadScript' :: FilePath -> IO RawModule
loadScript' = cast1 LibTorch.load

-- | Runs a script module's forward pass on a list of 'IValue' arguments.
instance HasForward ScriptModule [IValue] IValue where
  -- Pure variant: runs 'forwardStoch' through 'unsafePerformIO'.
  -- NOTE(review): this assumes the forward call is free of observable side
  -- effects — confirm for the modules in use.
  forward module' = unsafePerformIO . forwardStoch module'

  -- Two marshalling layers: the outer 'cast2' converts between 'IValue' and
  -- 'RawIValue', the inner one (in forward') converts to what
  -- 'LibTorch.forward' expects.
  forwardStoch = cast2 forward'
    where
      forward' :: ScriptModule -> [RawIValue] -> IO RawIValue
      forward' = cast2 LibTorch.forward

-- | Registers a tensor on a 'RawModule' under the given name via
-- 'LibTorch.registerParameter'. The 'Bool' is forwarded unchanged to the
-- binding (presumably libtorch's @is_buffer@ flag — confirm).
registerParameter :: RawModule -> String -> Tensor -> Bool -> IO ()
registerParameter = cast4 LibTorch.registerParameter

-- | Registers the second module on the first one under the given name via
-- 'LibTorch.registerModule'.
registerModule :: RawModule -> String -> RawModule -> IO ()
registerModule = cast3 LibTorch.registerModule

-- | Returns the parameters of a 'ScriptModule' as reported by
-- 'LibTorch.getParameters'.
--
-- The IO call is run through 'unsafePerformIO'.
-- NOTE(review): this assumes the module is not modified behind the
-- 'ScriptModule' wrapper — confirm.
getParameters ::
  -- | module
  ScriptModule ->
  -- | output
  [Tensor]
getParameters = unsafePerformIO . cast1 LibTorch.getParameters

-- | Returns the parameters of a 'RawModule' as reported by
-- 'LibTorch.getParameters'; the 'IO' counterpart of 'getParameters'.
getParametersIO ::
  -- | module
  RawModule ->
  -- | output
  IO [Tensor]
getParametersIO = cast1 LibTorch.getParameters

-- | Passes the given tensors to 'LibTorch.setParameters' for the module.
-- NOTE(review): the list is presumably expected in the same order as
-- returned by 'getParametersIO' — confirm against the binding.
setParameters :: RawModule -> [Tensor] -> IO ()
setParameters = cast2 LibTorch.setParameters

-- | Returns a clone of the module, with parameters installed according to the
-- 'LoadMode'.
--
-- * 'WithRequiredGrad': the clone's parameters are set to @inputs@ after each
--   tensor has been passed through 'makeIndependent'.
-- * 'WithoutRequiredGrad': only the clone is made.
--
-- NOTE(review): in the 'WithoutRequiredGrad' branch @inputs@ is not used at
-- all, so the result keeps the original parameters. That is the existing
-- behavior and is preserved here; confirm whether it is intended.
--
-- The work is done in IO on a fresh clone and exposed as a pure function via
-- 'unsafePerformIO'; the argument module itself is not written to here.
updateParameters :: LoadMode -> ScriptModule -> [Tensor] -> ScriptModule
updateParameters mode module' inputs = unsafePerformIO $
  case mode of
    WithoutRequiredGrad -> cast1 LibTorch.clone module'
    WithRequiredGrad -> do
      r <- cast1 LibTorch.clone module'
      paramsWithRequiredGrad <- forM inputs makeIndependent
      setParameters' r (map toDependent paramsWithRequiredGrad)
      return r
  where
    -- Same binding as 'setParameters', typed for the 'ScriptModule' wrapper so
    -- it can be applied to the fresh clone.
    setParameters' :: ScriptModule -> [Tensor] -> IO ()
    setParameters' = cast2 LibTorch.setParameters

-- | Returns the @(name, tensor)@ pairs reported by
-- 'LibTorch.getNamedParameters' for the module.
--
-- The IO call is run through 'unsafePerformIO'.
-- NOTE(review): this assumes the module is not modified behind the
-- 'ScriptModule' wrapper — confirm.
getNamedParameters ::
  -- | module
  ScriptModule ->
  -- | output
  [(String, Tensor)]
getNamedParameters (UnsafeScriptModule m) = unsafePerformIO $ do
  dat <- LibTorch.getNamedParameters m
  -- Convert each raw (name, tensor) pair to its Haskell-side types.
  forM dat $ \(key, value) ->
    (,) <$> uncast key return <*> uncast value return

-- | Returns the @(name, tensor)@ pairs reported by 'LibTorch.getNamedBuffers'
-- for the module.
--
-- The IO call is run through 'unsafePerformIO'.
-- NOTE(review): this assumes the module is not modified behind the
-- 'ScriptModule' wrapper — confirm.
getNamedBuffers ::
  -- | module
  ScriptModule ->
  -- | output
  [(String, Tensor)]
getNamedBuffers (UnsafeScriptModule m) = unsafePerformIO $ do
  dat <- LibTorch.getNamedBuffers m
  -- Convert each raw (name, tensor) pair to its Haskell-side types.
  forM dat $ \(key, value) ->
    (,) <$> uncast key return <*> uncast value return

-- | Load all attributes, including training flags, as @(name, value)@ pairs
-- reported by 'LibTorch.getNamedAttributes'.
--
-- According to the original note, attributes of Tensor type are returned as
-- 'IVObject' here. To get the tensors themselves, use 'getNamedParameters'
-- and 'getNamedBuffers'.
--
-- The IO call is run through 'unsafePerformIO'.
-- NOTE(review): this assumes the module is not modified behind the
-- 'ScriptModule' wrapper — confirm.
getNamedAttributes ::
  -- | module
  ScriptModule ->
  -- | output
  [(String, IValue)]
getNamedAttributes (UnsafeScriptModule m) = unsafePerformIO $ do
  dat <- LibTorch.getNamedAttributes m
  -- Convert each raw (name, ivalue) pair to its Haskell-side types.
  forM dat $ \(key, value) ->
    (,) <$> uncast key return <*> uncast value return

-- | Returns the @(name, module)@ pairs reported by 'LibTorch.getNamedModules'
-- for the module.
--
-- The IO call is run through 'unsafePerformIO'.
-- NOTE(review): this assumes the module is not modified behind the
-- 'ScriptModule' wrapper — confirm.
getNamedModules ::
  -- | module
  ScriptModule ->
  -- | output
  [(String, ScriptModule)]
getNamedModules (UnsafeScriptModule m) = unsafePerformIO $ do
  dat <- LibTorch.getNamedModules m
  -- Convert each raw (name, module) pair to its Haskell-side types.
  forM dat $ \(key, value) ->
    (,) <$> uncast key return <*> uncast value return

-- | Returns the @(name, module)@ pairs reported by
-- 'LibTorch.getNamedChildren' for the module.
--
-- The IO call is run through 'unsafePerformIO'.
-- NOTE(review): this assumes the module is not modified behind the
-- 'ScriptModule' wrapper — confirm.
getNamedChildren ::
  -- | module
  ScriptModule ->
  -- | output
  [(String, ScriptModule)]
getNamedChildren (UnsafeScriptModule m) = unsafePerformIO $ do
  dat <- LibTorch.getNamedChildren m
  -- Convert each raw (name, module) pair to its Haskell-side types.
  forM dat $ \(key, value) ->
    (,) <$> uncast key return <*> uncast value return

-- | Converts a 'RawModule' into a 'ScriptModule'. The module is cloned first
-- with 'cloneRawModule', and the clone's pointer is what gets wrapped, so the
-- result does not wrap the argument's own pointer.
toScriptModule :: RawModule -> IO ScriptModule
toScriptModule rawModule = do
  (UnsafeRawModule r) <- cloneRawModule rawModule
  return $ UnsafeScriptModule r

-- | Converts a 'ScriptModule' into a 'RawModule'. The module is cloned first
-- via 'LibTorch.clone', and the clone's pointer is what gets wrapped, so the
-- result does not wrap the argument's own pointer.
toRawModule :: ScriptModule -> IO RawModule
toRawModule scriptModule = do
  (UnsafeScriptModule r) <- clone' scriptModule
  return $ UnsafeRawModule r
  where
    -- 'LibTorch.clone' typed for the 'ScriptModule' wrapper.
    clone' :: ScriptModule -> IO ScriptModule
    clone' = cast1 LibTorch.clone

-- | Clones a 'RawModule' via 'LibTorch.clone'.
cloneRawModule :: RawModule -> IO RawModule
cloneRawModule = cast1 LibTorch.clone

-- | Mode passed to 'setRuntimeMode': evaluation or training.
data RuntimeMode = Eval | Train deriving (Show, Eq)

-- | Sets the module's runtime mode by calling 'LibTorch.train' with 'True'
-- for 'Train' and 'False' for 'Eval'.
setRuntimeMode :: RawModule -> RuntimeMode -> IO ()
setRuntimeMode rmod mode = cast2 LibTorch.train rmod (mode == Train)

-- | Passes a source string to 'LibTorch.define' for the module. In this file
-- it is used to add a TorchScript @forward@ method (see
-- 'traceWithParameters').
define :: RawModule -> String -> IO ()
define = cast2 LibTorch.define

-- | Produces a textual dump of the module via 'LibTorch.dumpToStr'. The three
-- flags are forwarded in order and select what is included in the dump.
dumpToStr ::
  -- | module
  ScriptModule ->
  -- | print_method_bodies
  Bool ->
  -- | print_attr_values
  Bool ->
  -- | print_param_values
  Bool ->
  -- | output
  IO String
dumpToStr = cast4 LibTorch.dumpToStr

-- | 'dumpToStr' with all three flags (method bodies, attribute values,
-- parameter values) set to 'True'.
dumpToStr' :: ScriptModule -> IO String
dumpToStr' obj = dumpToStr obj True True True

-- | Calls the named method of the module with a list of arguments via
-- 'LibTorch.runMethod'.
--
-- The IO call is run through 'unsafePerformIO'.
-- NOTE(review): this assumes the method has no observable side effects —
-- confirm for the methods being called.
runMethod ::
  -- | module
  ScriptModule ->
  -- | func
  String ->
  -- | inputs
  [IValue] ->
  -- | output
  IValue
runMethod module' func = unsafePerformIO . cast3 runMethod' module' func
  where
    -- Inner marshalling layer; the outer 'cast3' converts between 'IValue'
    -- and 'RawIValue'.
    runMethod' :: ScriptModule -> String -> [RawIValue] -> IO RawIValue
    runMethod' = cast3 LibTorch.runMethod

-- | Calls the named method of the module with a single argument via
-- 'LibTorch.runMethod1'.
--
-- The IO call is run through 'unsafePerformIO'.
-- NOTE(review): this assumes the method has no observable side effects —
-- confirm for the methods being called.
runMethod1 ::
  -- | module
  ScriptModule ->
  -- | func
  String ->
  -- | inputs
  IValue ->
  -- | output
  IValue
runMethod1 module' func = unsafePerformIO . cast3 runMethod1' module' func
  where
    -- Inner marshalling layer; the outer 'cast3' converts between 'IValue'
    -- and 'RawIValue'.
    runMethod1' :: ScriptModule -> String -> RawIValue -> IO RawIValue
    runMethod1' = cast3 LibTorch.runMethod1

-- | Exposes a script module's parameters to the 'Parameterized' machinery.
instance Parameterized ScriptModule where
  -- Wrap each tensor from 'getParameters' directly in 'IndependentTensor'.
  flattenParameters module' = map IndependentTensor $ getParameters module'

  -- Pull as many parameters from the stream as the module currently has and
  -- build an updated clone from them via 'updateParameters'.
  _replaceParameters module' = do
    let len = length (getParameters module')
    ps' <- replicateM len nextParameter
    return $ updateParameters WithRequiredGrad module' (map toDependent ps')

-- | Traces a Haskell function on the given example inputs via
-- 'LibTorch.trace' and returns the resulting module. The first string is
-- passed as the module name and the second as the function name.
trace ::
  -- | moduleName
  String ->
  -- | functionName
  String ->
  -- | function
  ([Tensor] -> IO [Tensor]) ->
  -- | inputs
  [Tensor] ->
  -- | output
  IO RawModule
trace moduleName functionName func =
  cast3
    (\m f inps -> LibTorch.trace m f (trans func) inps)
    moduleName
    functionName
  where
    -- Lift the Haskell-level function to one over raw tensor lists, which is
    -- the shape 'LibTorch.trace' takes: unmarshal the inputs, run the
    -- function, marshal the outputs.
    trans :: ([Tensor] -> IO [Tensor]) -> ForeignPtr TensorList -> IO (ForeignPtr TensorList)
    trans g inputs =
      uncast inputs $ \inputs' -> do
        ret <- g inputs'
        cast ret return

-- | Generates a torchscript module from a 'Parameterized' instance of
-- hasktorch.
--
-- The function is traced under the name @forwardWithParameters@, taking the
-- flattened parameters followed by the inputs. The parameters are then
-- registered on the module as @p0@, @p1@, ..., and a @forward@ method taking
-- only the inputs (@i0@, @i1@, ...) is defined that calls
-- @forwardWithParameters@ with the registered parameters.
--
-- Usage:
--
-- @
-- let example_inputs = asTensor (4 :: Float)
-- init_parameters <- sample MonoSpec
-- mutableTorchscript <-
--   traceWithParameters
--     "MyModule"
--     (\\parameters [example_inputs'] -> return [traced_function parameters example_inputs'])
--     init_parameters
--     [example_inputs]
-- immutableTorchscript <- toScriptModule mutableTorchscript
-- saveScript immutableTorchscript "\<your torchscript file\>"
-- @
traceWithParameters ::
  Parameterized f =>
  -- | module name
  String ->
  -- | traced function
  (f -> [Tensor] -> IO [Tensor]) ->
  -- | initial parameters
  f ->
  -- | example inputs
  [Tensor] ->
  -- | torchscript module
  IO RawModule
traceWithParameters moduleName func parameterized_parameters inputs = do
  let parameters = map toDependent (flattenParameters parameterized_parameters)
      -- Rebuild the parameterized value from a flat tensor list.
      fromParams params = replaceParameters parameterized_parameters (map IndependentTensor params)
      plen = length parameters
      ilen = length inputs
  r <-
    trace
      moduleName
      "forwardWithParameters"
      ( \parametersAndInputs ->
          -- The traced function receives parameters and inputs as one list;
          -- split it back at the parameter count.
          func
            (fromParams (take plen parametersAndInputs))
            (drop plen parametersAndInputs)
      )
      (parameters ++ inputs)
  forM_ (zip [0 :: Int ..] parameters) $ \(i, p) ->
    registerParameter r ("p" ++ show i) p False
  let argNames = map (\i -> "i" ++ show i) [0 .. ilen - 1]
      paramRefs = map (\i -> "self.p" ++ show i) [0 .. plen - 1]
  -- Join the names as single lists so that an empty parameter list or an
  -- empty input list does not leave a stray comma in the generated source.
  -- With at least one parameter and one input the text is unchanged.
  define r $
    "def forward(" ++ intercalate ", " ("self" : argNames) ++ "):\n"
      ++ "    return self.forwardWithParameters("
      ++ intercalate ", " (paramRefs ++ argNames)
      ++ " )\n"
  return r

-- | Traces a Haskell function on the given example inputs via
-- 'LibTorch.traceAsGraph' and returns the resulting 'Graph'.
traceAsGraph ::
  -- | function
  ([Tensor] -> IO [Tensor]) ->
  -- | inputs
  [Tensor] ->
  -- | output
  IO Graph
traceAsGraph func = cast1 (LibTorch.traceAsGraph (trans func))
  where
    -- Lift the Haskell-level function to one over raw tensor lists, which is
    -- the shape 'LibTorch.traceAsGraph' takes: unmarshal the inputs, run the
    -- function, marshal the outputs.
    trans :: ([Tensor] -> IO [Tensor]) -> ForeignPtr TensorList -> IO (ForeignPtr TensorList)
    trans g inputs =
      uncast inputs $ \inputs' -> do
        ret <- g inputs'
        cast ret return

-- | Render a 'Graph' as a 'String' by calling @LibTorch.printGraph@.
printGraph :: Graph -> IO String
printGraph = cast1 LibTorch.printGraph

-- | Output an ONNX file from a graph. (Really experimental implementation.)
-- printOnnx uses the export_onnx function of libtorch.
-- It outputs the following error, because the prim::Constant symbol used by torchscript does not exist.
-- -- Exception: ONNX export failed: Couldn't export operator prim::Constant
-- -- Defined at:
-- --   Graph we tried to export:
-- --   graph(%0 : Float(),
-- --               %1 : Float()):
-- --     %2 : int = prim::Constant[value=1]()
-- --   %3 : Float() = aten::add(%0, %1, %2)
-- --   return (%3)
-- -- ; type: std::runtime_error
-- On the other hand, torch.onnx.export of python works.
-- onnx's symbol map is in python code.
-- https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py
--
-- If you need an ONNX file, first make a torchscript module by tracing, then convert the torchscript into ONNX with Python code.
printOnnx :: Graph -> IO String
printOnnx = cast1 LibTorch.printOnnx

-- | Convert a 'Graph' handle into the plain Haskell 'JitGraph' record.
--
-- The graph's inputs, outputs and nodes are read through the unmanaged
-- bindings inside 'withForeignPtr' / @Unmanaged.withJitGraph@, and copied into
-- 'JitValue' and 'JitNode' records.
graphToJitGraph :: Graph -> IO JitGraph
graphToJitGraph (UnsafeGraph graph) =
  withForeignPtr graph $ \g0 -> Unmanaged.withJitGraph g0 $ \g -> do
    graphInputs <- toJitValue =<< Unmanaged.graphInputs g
    graphOutputs <- toJitValue =<< Unmanaged.graphOutputs g
    graphNodes <- toJitNode =<< Unmanaged.graphNodes g
    -- RecordWildCards: the fields are filled from the three bindings above.
    pure JitGraph {..}
  where
    -- Read the id and the type string of every value in the container.
    -- Deliberately left without a signature so its inferred, polymorphic type is kept.
    toJitValue inputs =
      forM inputs $ \i -> do
        valueId <- cast1 Unmanaged.valueId i
        valueType <- cast0 (cast1 Unmanaged.valueType i :: IO (ForeignPtr ATen.StdString))
        pure JitValue {..}
    -- Read the inputs, outputs and kind string of every node in the container.
    toJitNode nodes =
      forM nodes $ \n -> do
        nodeInputs <- toJitValue =<< Unmanaged.nodeInputs n
        nodeOutputs <- toJitValue =<< Unmanaged.nodeOutputs n
        nodeKind <- cast0 (cast1 Unmanaged.nodeKind n :: IO (ForeignPtr ATen.StdString))
        pure JitNode {..}

-- | Element-wise conversion between lists of 'IValue' and lists of 'RawIValue',
-- delegating each element to the 'Castable' instance for a single value.
instance Castable [IValue] [RawIValue] where
  -- Convert every element, then hand the whole list to the continuation.
  cast a f = forM a (`cast` return) >>= f
  uncast a f = forM a (`uncast` return) >>= f

-- | Conversion between the Haskell-side 'IValue' sum type and 'RawIValue'.
--
-- 'cast' builds a 'RawIValue' from each supported constructor and throws a
-- 'userError' for any constructor without a clause. 'uncast' probes the raw
-- value with the @iValue_is*@ predicates in order and throws a 'userError'
-- when none of them matches.
instance Castable IValue RawIValue where
  cast IVNone f = newIValue >>= f
  cast (IVTensor (Unsafe v)) f = toIValue v >>= f
  cast (IVDouble v) f = toIValue v >>= f
  cast (IVInt v) f = toIValue v >>= f
  cast (IVBool v) f = toIValue v >>= f
  cast (IVTuple v) f = do
    rawIValues <- cast v return :: IO [RawIValue]
    c10tuple <- cast rawIValues return :: IO (ForeignPtr (ATen.C10Ptr ATen.IVTuple))
    f =<< toIValue c10tuple
  cast (IVIntList v) f = do
    v' <- cast v return :: IO (ForeignPtr (ATen.C10List Int64))
    f =<< toIValue v'
  cast (IVDoubleList v) f = do
    -- Convert element-wise to the C type before building the c10 list.
    cdoubles <- forM v (`cast` return) :: IO [CDouble]
    c10list <- cast cdoubles return :: IO (ForeignPtr (ATen.C10List CDouble))
    f =<< toIValue c10list
  cast (IVBoolList v) f = do
    cbools <- forM v (`cast` return) :: IO [CBool]
    c10list <- cast cbools return :: IO (ForeignPtr (ATen.C10List CBool))
    f =<< toIValue c10list
  cast (IVString v) f = do
    v' <- cast v return :: IO (ForeignPtr ATen.StdString)
    f =<< toIValue v'
  cast (IVTensorList v) f = do
    v' <- cast v return :: IO (ForeignPtr (ATen.C10List ATen.Tensor))
    f =<< toIValue v'
  cast (IVGenericList v) f = do
    rawIValues <- cast v return :: IO [RawIValue]
    c10list <- cast rawIValues return :: IO (ForeignPtr (ATen.C10List ATen.IValue))
    f =<< toIValue c10list
  cast (IVGenericDict v) f = do
    -- Keys and values are converted separately and re-paired in order.
    keys <- cast (map fst v) return :: IO [RawIValue]
    values <- cast (map snd v) return :: IO [RawIValue]
    let rawIValues = zip keys values
    c10dict <- cast rawIValues return :: IO (ForeignPtr (ATen.C10Dict '(ATen.IValue, ATen.IValue)))
    f =<< toIValue c10dict
  --  cast (IVBlob (UnsafeBlob v)) f = toIValue v >>= f
  --  cast (IVFuture (UnsafeFuture v)) f = toIValue v >>= f
  --  cast (IVDevice v) f = toIValue v >>= f
  --  cast (IVObject (UnsafeObject v)) f = toIValue v >>= f
  --  cast (IVUninitialized) f = f (toIValue v)
  --  cast (IVCapsule v) f = toIValue v >>= f
  -- Every remaining constructor is rejected at runtime.
  cast a f = throwIO $ userError $ "Unsupported data-type:" ++ show a
  uncast obj f =
    select
      [ (iValue_isNone obj, f IVNone),
        (iValue_isTensor obj, fromIValue obj >>= f . IVTensor . Unsafe),
        (iValue_isDouble obj, fromIValue obj >>= f . IVDouble),
        (iValue_isInt obj, fromIValue obj >>= f . IVInt),
        (iValue_isBool obj, fromIValue obj >>= f . IVBool),
        ( iValue_isString obj,
          do
            v <- fromIValue obj :: IO (ForeignPtr ATen.StdString)
            str <- uncast v return :: IO String
            f (IVString str)
        ),
        ( iValue_isTensorList obj,
          do
            v' <- fromIValue obj :: IO (ForeignPtr (ATen.C10List ATen.Tensor))
            ts <- uncast v' return :: IO [Tensor]
            f (IVTensorList ts)
        ),
        ( iValue_isDoubleList obj,
          do
            v' <- fromIValue obj :: IO (ForeignPtr (ATen.C10List CDouble))
            cdoubles <- uncast v' return :: IO [CDouble]
            doubles <- forM cdoubles (`uncast` return) :: IO [Double]
            f (IVDoubleList doubles)
        ),
        ( iValue_isIntList obj,
          do
            v' <- fromIValue obj :: IO (ForeignPtr (ATen.C10List Int64))
            ts <- uncast v' return :: IO [Int64]
            f (IVIntList ts)
        ),
        ( iValue_isBoolList obj,
          do
            v' <- fromIValue obj :: IO (ForeignPtr (ATen.C10List CBool))
            cbools <- uncast v' return :: IO [CBool]
            bools <- forM cbools (`uncast` return) :: IO [Bool]
            f (IVBoolList bools)
        ),
        ( iValue_isTuple obj,
          do
            c10tuple <- fromIValue obj :: IO (ForeignPtr (ATen.C10Ptr ATen.IVTuple))
            rawIValues <- uncast c10tuple return :: IO [RawIValue]
            ts <- uncast rawIValues return :: IO [IValue]
            f (IVTuple ts)
        ),
        -- The generic list test comes after the typed list tests above.
        ( iValue_isList obj,
          do
            c10list <- fromIValue obj :: IO (ForeignPtr (ATen.C10List ATen.IValue))
            rawIValues <- uncast c10list return :: IO [RawIValue]
            ts <- uncast rawIValues return :: IO [IValue]
            f (IVGenericList ts)
        ),
        ( iValue_isGenericDict obj,
          do
            c10dict <- fromIValue obj :: IO (ForeignPtr (ATen.C10Dict '(ATen.IValue, ATen.IValue)))
            rawIValues <- uncast c10dict return :: IO [(RawIValue, RawIValue)]
            ts <- forM rawIValues $ \(a, b) -> do
              a' <- uncast a return
              b' <- uncast b return
              return (a', b')
            f (IVGenericDict ts)
        ),
        -- These kinds are only tagged; no payload is extracted from the raw value.
        (iValue_isBlob obj, f IVBlob),
        (iValue_isFuture obj, f IVFuture),
        (iValue_isDevice obj, f IVDevice),
        (iValue_isObject obj, f IVObject),
        (iValue_isCapsule obj, f IVCapsule)
      ]
    where
      -- Run the body of the first entry whose predicate yields 1; later
      -- predicates are not evaluated. An exhausted list is an error.
      select [] = throwIO $ userError "Unsupported IValue"
      select ((cond, body) : xs) =
        cond >>= \case
          1 -> body
          _ -> select xs