{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Class where

import Control.Exception (Exception (..))
import Control.Monad.Catch (MonadThrow (..))
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Indexed (IxPointed (ireturn), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.State (MonadState (get, put))
import Data.Bifunctor (Bifunctor (bimap))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Constraint, Type)
import qualified Data.Map.Strict as Map
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Typeable (Typeable)
import qualified Data.Vector as V hiding (uncons)
import qualified Data.Vector.Generic.Sized.Internal as VGS
import qualified Data.Vector.Sized as VS
import Foreign.ForeignPtr (ForeignPtr)
import GHC.Generics (Generic (..), K1 (..), M1 (..), U1 (..), (:*:) (..))
import GHC.TypeLits (Nat, natVal, type (+))
import Torch.GraduallyTyped.Device (Device, DeviceType)
import qualified Torch.GraduallyTyped.Internal.Vector as V
import Torch.GraduallyTyped.Prelude.TypeLits (SNat (..))
import Torch.GraduallyTyped.Random (Generator)
import Torch.GraduallyTyped.Shape.Type (SDim)
import Torch.GraduallyTyped.Tensor.Type (Tensor (..), TensorSpec (..), UncheckedTensor, sCheckedDataType, sCheckedLayout, sCheckedShape, sSetDevice, sSetGradient)
import qualified Torch.Internal.Type as ATen (Tensor)
import qualified Torch.Script (IValue (..))
import qualified Torch.Serialize (pickleLoad, pickleSave)
import qualified Torch.Tensor (Tensor (Unsafe))

type NamedModel :: Type -> Type
-- | A model paired with a textual name.
--
-- The first field is the name, the second field is the wrapped model.
data NamedModel model = NamedModel Text model
  deriving stock (Eq, Ord, Show, Generic)

-- | Infix, bidirectional synonym for the 'NamedModel' constructor:
-- @name ::> model@ is the same value as @NamedModel name model@,
-- both when constructing and when pattern matching.
pattern (::>) :: Text -> model -> NamedModel model
pattern (::>) name model = NamedModel name model

type HasForward ::
  Type ->
  Type ->
  Device (DeviceType Nat) ->
  Type ->
  Device (DeviceType Nat) ->
  Constraint
-- | Models that can be applied to an input while threading a random generator.
--
-- The functional dependencies state that the model type, the input type, and
-- the device of the incoming generator together determine both the output type
-- and the device of the returned generator.
class
  HasForward
    model
    input
    generatorDevice
    output
    generatorOutputDevice
    | model input generatorDevice -> output,
      model input generatorDevice -> generatorOutputDevice
  where
  -- | @forward m i g@ for a model @m@, an input @i@, and a generator @g@
  -- returns the tuple @(o, g')@ where @o@ is the output of the model applied to the input
  -- and @g'@ is the updated generator.
  -- @forward m i g@ may throw an exception if the input @i@ or the generator @g@
  -- are not compatible with the model @m@.
  --
  -- The default implementation is available for any model with a 'Generic'
  -- instance whose representation satisfies 'GHasForward'; it converts the
  -- model with 'from' and delegates to 'gForward'.
  forward ::
    forall m.
    MonadThrow m =>
    -- | model
    model ->
    -- | model input, typically a tensor or a tuple of tensors
    input ->
    -- | random generator
    Generator generatorDevice ->
    -- | output of the model with an updated generator
    m (output, Generator generatorOutputDevice)
  default forward ::
    forall m.
    ( MonadThrow m,
      Generic model,
      GHasForward (Rep model) input generatorDevice output generatorOutputDevice
    ) =>
    model ->
    input ->
    Generator generatorDevice ->
    m (output, Generator generatorOutputDevice)
  forward model = gForward (from model)

-- | The unit model is the identity: it returns the input and the generator
-- unchanged. The model argument is never inspected.
instance HasForward () input generatorDevice input generatorDevice where
  forward _ input g = pure (input, g)

-- | A named model forwards exactly like the model it wraps; the name is ignored.
instance
  HasForward model input generatorDevice output generatorOutputDevice =>
  HasForward (NamedModel model) input generatorDevice output generatorOutputDevice
  where
  forward (NamedModel _ model) = forward model

type GHasForward ::
  (Type -> Type) ->
  Type ->
  Device (DeviceType Nat) ->
  Type ->
  Device (DeviceType Nat) ->
  Constraint
-- | Counterpart of 'HasForward' for generic representations, i.e. for
-- type constructors of kind @Type -> Type@ rather than plain model types.
--
-- As with 'HasForward', the representation, the input type, and the incoming
-- generator device determine the output type and the outgoing generator device.
class
  GHasForward gModel input generatorDevice output generatorOutputDevice
    | gModel input generatorDevice -> output,
      gModel input generatorDevice -> generatorOutputDevice
  where
  -- | Apply a generic model representation to an input and a generator,
  -- producing the output together with the updated generator.
  gForward ::
    forall m c.
    MonadThrow m =>
    -- | generic representation of the model
    gModel c ->
    -- | model input
    input ->
    -- | random generator
    Generator generatorDevice ->
    -- | output paired with the updated generator
    m (output, Generator generatorOutputDevice)

-- | A product of two representations is applied left to right: the output and
-- the updated generator of the left component become the input and the
-- generator of the right component.
instance
  ( GHasForward
      gModelA
      inputA
      generatorDevice
      outputA
      generatorOutputADevice,
    GHasForward
      gModelB
      outputA
      generatorOutputADevice
      outputB
      generatorOutputDevice
  ) =>
  GHasForward
    (gModelA :*: gModelB)
    inputA
    generatorDevice
    outputB
    generatorOutputDevice
  where
  gForward (gModelA :*: gModelB) input =
    -- The indexed state monad carries the generator, whose device index may
    -- change from one step to the next.
    runIxStateT $
      ireturn input
        >>>= IxStateT . gForward gModelA
        >>>= IxStateT . gForward gModelB

-- | Metadata wrappers are transparent: unwrap 'M1' and forward through the
-- representation underneath.
instance
  GHasForward gModel input generatorDevice output generatorOutputDevice =>
  GHasForward
    (M1 i t gModel)
    input
    generatorDevice
    output
    generatorOutputDevice
  where
  gForward (M1 gModel) = gForward gModel

-- | A field of the representation is forwarded through with the 'HasForward'
-- instance of the field's own type.
instance
  HasForward model input generatorDevice output generatorOutputDevice =>
  GHasForward
    (K1 i model)
    input
    generatorDevice
    output
    generatorOutputDevice
  where
  gForward (K1 model) = forward model

-- | A constructor without fields is the identity: input and generator are
-- returned unchanged.
instance GHasForward U1 input generatorDevice input generatorDevice where
  gForward U1 input g = pure (input, g)

-- | Pairs of models. The output type and generator device of the first
-- component are the input type and generator device of the second.
--
-- No method is given here, so 'forward' is the class default.
instance
  ( HasForward m1 input generatorDevice o1 d1,
    HasForward m2 o1 d1 output generatorOutputDevice
  ) =>
  HasForward (m1, m2) input generatorDevice output generatorOutputDevice

-- | Triples of models. Each component's output type and generator device are
-- the next component's input type and generator device.
--
-- No method is given here, so 'forward' is the class default.
instance
  ( HasForward m1 input generatorDevice o1 d1,
    HasForward m2 o1 d1 o2 d2,
    HasForward m3 o2 d2 output generatorOutputDevice
  ) =>
  HasForward (m1, m2, m3) input generatorDevice output generatorOutputDevice

-- | 4-tuples of models. Each component's output type and generator device are
-- the next component's input type and generator device.
--
-- No method is given here, so 'forward' is the class default.
instance
  ( HasForward m1 input generatorDevice o1 d1,
    HasForward m2 o1 d1 o2 d2,
    HasForward m3 o2 d2 o3 d3,
    HasForward m4 o3 d3 output generatorOutputDevice
  ) =>
  HasForward (m1, m2, m3, m4) input generatorDevice output generatorOutputDevice

-- | 5-tuples of models. Each component's output type and generator device are
-- the next component's input type and generator device.
--
-- No method is given here, so 'forward' is the class default.
instance
  ( HasForward m1 input generatorDevice o1 d1,
    HasForward m2 o1 d1 o2 d2,
    HasForward m3 o2 d2 o3 d3,
    HasForward m4 o3 d3 o4 d4,
    HasForward m5 o4 d4 output generatorOutputDevice
  ) =>
  HasForward (m1, m2, m3, m4, m5) input generatorDevice output generatorOutputDevice

-- | 6-tuples of models. Each component's output type and generator device are
-- the next component's input type and generator device.
--
-- No method is given here, so 'forward' is the class default.
instance
  ( HasForward m1 input generatorDevice o1 d1,
    HasForward m2 o1 d1 o2 d2,
    HasForward m3 o2 d2 o3 d3,
    HasForward m4 o3 d3 o4 d4,
    HasForward m5 o4 d4 o5 d5,
    HasForward m6 o5 d5 output generatorOutputDevice
  ) =>
  HasForward (m1, m2, m3, m4, m5, m6) input generatorDevice output generatorOutputDevice

-- | 7-tuples of models. Each component's output type and generator device are
-- the next component's input type and generator device.
--
-- No method is given here, so 'forward' is the class default.
instance
  ( HasForward m1 input generatorDevice o1 d1,
    HasForward m2 o1 d1 o2 d2,
    HasForward m3 o2 d2 o3 d3,
    HasForward m4 o3 d3 o4 d4,
    HasForward m5 o4 d4 o5 d5,
    HasForward m6 o5 d5 o6 d6,
    HasForward m7 o6 d6 output generatorOutputDevice
  ) =>
  HasForward (m1, m2, m3, m4, m5, m6, m7) input generatorDevice output generatorOutputDevice

type Wrap :: Type -> Type
-- | Single-field wrapper around a value of type @a@.
--
-- 'ListToTuple' uses it as the result for one-element lists.
newtype Wrap a = Wrap a
  deriving stock (Eq, Ord, Show, Generic)

-- The specification of a wrapped model is the wrapped specification.
type instance
  ModelSpec (Wrap a) =
    Wrap (ModelSpec a)

-- | Initializing a 'Wrap' yields the wrapped result of initializing its
-- content. No method is given here, so 'initialize' is the class default.
instance
  HasInitialize inner generatorDevice inner' generatorOutputDevice =>
  HasInitialize (Wrap inner) generatorDevice (Wrap inner') generatorOutputDevice

-- | 'Wrap' has a state dictionary whenever its content has one.
-- No methods are given here; the instance relies on the defaults of
-- 'HasStateDict' (the class is declared outside this section).
instance HasStateDict inner => HasStateDict (Wrap inner)

-- | 'Wrap' can be applied to an input whenever its content can, with the same
-- input, output, and generator devices.
-- No method is given here, so 'forward' is the class default.
instance
  HasForward inner input generatorDevice output generatorOutputDevice =>
  HasForward (Wrap inner) input generatorDevice output generatorOutputDevice

type ListToTuple :: [Type] -> Type
-- | Turns a type-level list of zero to seven types into a tuple-like type.
--
-- The empty list becomes @()@, a one-element list becomes 'Wrap', and longer
-- lists become the ordinary tuple of the same arity. The family is injective:
-- the resulting type determines the list.
type family ListToTuple xs = tuple | tuple -> xs where
  ListToTuple '[] = ()
  ListToTuple '[t1] = Wrap t1
  ListToTuple '[t1, t2] = (t1, t2)
  ListToTuple '[t1, t2, t3] = (t1, t2, t3)
  ListToTuple '[t1, t2, t3, t4] = (t1, t2, t3, t4)
  ListToTuple '[t1, t2, t3, t4, t5] = (t1, t2, t3, t4, t5)
  ListToTuple '[t1, t2, t3, t4, t5, t6] = (t1, t2, t3, t4, t5, t6)
  ListToTuple '[t1, t2, t3, t4, t5, t6, t7] = (t1, t2, t3, t4, t5, t6, t7)

type ModelStack :: [Type] -> Type
-- | A stack of models indexed by the type-level list of its component types.
--
-- The payload is the tuple-like type that 'ListToTuple' assigns to the list.
newtype ModelStack models = ModelStack (ListToTuple models)
  deriving stock (Generic)

-- Specifications of model stacks: the specification of a stack is the stack
-- of the component specifications, one equation per supported length (0-7).

type instance ModelSpec (ModelStack '[]) = ModelStack '[]

type instance
  ModelSpec (ModelStack '[a]) =
    ModelStack '[ModelSpec a]

type instance
  ModelSpec (ModelStack '[a, b]) =
    ModelStack '[ModelSpec a, ModelSpec b]

type instance
  ModelSpec (ModelStack '[a, b, c]) =
    ModelStack '[ModelSpec a, ModelSpec b, ModelSpec c]

type instance
  ModelSpec (ModelStack '[a, b, c, d]) =
    ModelStack '[ModelSpec a, ModelSpec b, ModelSpec c, ModelSpec d]

type instance
  ModelSpec (ModelStack '[a, b, c, d, e]) =
    ModelStack '[ModelSpec a, ModelSpec b, ModelSpec c, ModelSpec d, ModelSpec e]

type instance
  ModelSpec (ModelStack '[a, b, c, d, e, f]) =
    ModelStack
      '[ModelSpec a, ModelSpec b, ModelSpec c, ModelSpec d, ModelSpec e, ModelSpec f]

type instance
  ModelSpec (ModelStack '[a, b, c, d, e, f, g]) =
    ModelStack
      '[ModelSpec a, ModelSpec b, ModelSpec c, ModelSpec d, ModelSpec e, ModelSpec f, ModelSpec g]

-- | A model stack forwards through the tuple-like value it wraps.
instance
  HasForward (ListToTuple models) input generatorDevice output generatorOutputDevice =>
  HasForward (ModelStack models) input generatorDevice output generatorOutputDevice
  where
  forward (ModelStack models) = forward models

-- | A sized vector of length zero is the identity: input and generator are
-- returned unchanged and the vector is never inspected.
instance HasForward (VS.Vector 0 a) input generatorDevice input generatorDevice where
  forward _ input g = pure (input, g)

-- | A sized vector of length one forwards through its only element.
instance
  HasForward a input generatorDevice output generatorOutputDevice =>
  HasForward (VS.Vector 1 a) input generatorDevice output generatorOutputDevice
  where
  forward (VGS.Vector v) input g =
    case V.uncons v of
      Just (a, _) -> forward a input g
      -- Only reachable if the underlying vector is empty despite the length
      -- index; fail with a descriptive message instead of a bare
      -- pattern-match failure.
      Nothing ->
        error "Torch.GraduallyTyped.NN.Class.forward: Vector 1 is backed by an empty vector"

-- | A sized vector of models of one common type is applied element by element,
-- from first to last. The first element maps @input@ to @output@; every later
-- element maps @output@ to @output@, which is why the context asks for both.
--
-- NOTE(review): the OVERLAPPABLE pragma was already commented out in the
-- source as found and is left disabled here; confirm whether that is intended.
instance
  -- {-# OVERLAPPABLE #-}
  ( HasForward a input generatorDevice output generatorOutputDevice,
    HasForward a output generatorOutputDevice output generatorOutputDevice
  ) =>
  HasForward (VS.Vector n a) input generatorDevice output generatorOutputDevice
  where
  forward (VGS.Vector v) input g =
    case V.uncons v of
      Just (a, as) -> do
        -- The head consumes the original input and generator ...
        headResult <- forward a input g
        -- ... and each remaining element consumes the previous output and
        -- the generator that came back with it.
        V.foldM (\(output', g') a' -> forward a' output' g') headResult as
      -- An empty vector cannot produce a value of type @output@; fail with a
      -- descriptive message instead of a bare pattern-match failure.
      Nothing ->
        error "Torch.GraduallyTyped.NN.Class.forward: Vector n is backed by an empty vector"

type ModelSpec :: Type -> Type
-- | Open type family that maps a model type to the type of its specification.
--
-- The family is injective (@spec -> model@): a specification type determines
-- the model type it belongs to.
type family ModelSpec model = spec | spec -> model

type HasInitialize ::
  Type ->
  Device (DeviceType Nat) ->
  Type ->
  Device (DeviceType Nat) ->
  Constraint
-- | Models that can be created from their specification while threading a
-- random generator.
--
-- The functional dependencies state that the model type and the device of the
-- incoming generator determine both the resulting type and the device of the
-- returned generator.
class
  HasInitialize
    model
    generatorDevice
    output
    generatorOutputDevice
    | model generatorDevice -> output,
      model generatorDevice -> generatorOutputDevice
  where
  -- | @initialize spec g@ takes the specification @spec@ of a model and a
  -- generator @g@ and returns the tuple @(o, g')@ of the initialized value @o@
  -- and the updated generator @g'@.
  --
  -- The default implementation is available when both the specification and
  -- the output have 'Generic' instances whose representations satisfy
  -- 'GHasInitialize'; it converts the specification with 'from', delegates to
  -- 'gInitialize', and converts the result back with 'to'.
  initialize ::
    forall m.
    MonadThrow m =>
    ModelSpec model ->
    Generator generatorDevice ->
    m (output, Generator generatorOutputDevice)
  default initialize ::
    forall m.
    ( MonadThrow m,
      Generic (ModelSpec model),
      Generic output,
      GHasInitialize (Rep (ModelSpec model)) generatorDevice (Rep output) generatorOutputDevice
    ) =>
    ModelSpec model ->
    Generator generatorDevice ->
    m (output, Generator generatorOutputDevice)
  initialize modelSpec =
    runIxStateT $
      IxStateT (gInitialize (from modelSpec))
        >>>= ireturn . to

type GHasInitialize ::
  (Type -> Type) ->
  Device (DeviceType Nat) ->
  (Type -> Type) ->
  Device (DeviceType Nat) ->
  Constraint
-- | Counterpart of 'HasInitialize' for generic representations: both the
-- specification and the result are type constructors of kind @Type -> Type@.
--
-- The specification representation and the incoming generator device
-- determine the output representation and the outgoing generator device.
class
  GHasInitialize gModelSpec generatorDevice gOutput generatorOutputDevice
    | gModelSpec generatorDevice -> gOutput,
      gModelSpec generatorDevice -> generatorOutputDevice
  where
  -- | Build the output representation from a specification representation and
  -- a generator, returning it together with the updated generator.
  gInitialize ::
    forall m c.
    MonadThrow m =>
    -- | generic representation of the specification
    gModelSpec c ->
    -- | random generator
    Generator generatorDevice ->
    -- | generic representation of the result paired with the updated generator
    m (gOutput c, Generator generatorOutputDevice)

-- | A product of two specification representations is initialized left to
-- right; the generator returned by the left component is handed to the right
-- component, and the two results are paired up again with @(:*:)@.
instance
  ( GHasInitialize gModelSpecA generatorDevice gOutputA generatorOutputADevice,
    GHasInitialize gModelSpecB generatorOutputADevice gOutputB generatorOutputDevice
  ) =>
  GHasInitialize
    (gModelSpecA :*: gModelSpecB)
    generatorDevice
    (gOutputA :*: gOutputB)
    generatorOutputDevice
  where
  gInitialize (gModelSpecA :*: gModelSpecB) =
    runIxStateT $
      (:*:)
        <<$>> IxStateT (gInitialize gModelSpecA)
        <<*>> IxStateT (gInitialize gModelSpecB)

-- | Metadata wrappers are transparent: initialize the representation
-- underneath and wrap the result in 'M1' again.
instance
  GHasInitialize gModelSpec generatorDevice gOutput generatorOutputDevice =>
  GHasInitialize
    (M1 i t gModelSpec)
    generatorDevice
    (M1 i t gOutput)
    generatorOutputDevice
  where
  gInitialize (M1 gModelSpec) =
    runIxStateT $ M1 <<$>> IxStateT (gInitialize gModelSpec)

-- | A field holding a specification is initialized with the 'HasInitialize'
-- instance of the corresponding model, and the result is wrapped in 'K1'.
--
-- The equality constraint ties the field type to @ModelSpec model@.
instance
  ( HasInitialize model generatorDevice output generatorOutputDevice,
    ModelSpec model ~ modelSpec
  ) =>
  GHasInitialize
    (K1 i modelSpec)
    generatorDevice
    (K1 i output)
    generatorOutputDevice
  where
  gInitialize (K1 modelSpec) =
    -- The type application selects the 'HasInitialize' instance explicitly.
    runIxStateT $ K1 <<$>> IxStateT (initialize @model modelSpec)

-- | A constructor without fields initializes to itself; the generator is
-- returned unchanged.
instance GHasInitialize U1 generatorDevice U1 generatorDevice where
  gInitialize U1 g = pure (U1, g)

-- A dimension singleton is its own specification.
type instance ModelSpec (SDim dim) = SDim dim

-- | Initializing a dimension singleton returns it as is, together with the
-- unchanged generator.
instance HasInitialize (SDim dim) generatorDevice (SDim dim) generatorDevice where
  initialize dim g = pure (dim, g)

-- The unit model is its own specification.
type instance ModelSpec () = ()

-- | Initializing the unit model returns @()@ together with the unchanged
-- generator.
instance HasInitialize () generatorDevice () generatorDevice where
  initialize () g = pure ((), g)

-- The specification of a named model is the named specification of the model.
type instance ModelSpec (NamedModel model) = NamedModel (ModelSpec model)

-- | Initializes the wrapped specification and attaches the same name to the
-- result. The generator is threaded through the inner initialization.
instance
  HasInitialize model generatorDevice output generatorOutputDevice =>
  HasInitialize (NamedModel model) generatorDevice (NamedModel output) generatorOutputDevice
  where
  initialize (NamedModel modelName modelSpec) g = do
    (model, g') <- initialize modelSpec g
    pure (NamedModel modelName model, g')

-- Tuple specifications (arities 2 through 7): the specification of a tuple of
-- models is the tuple of the component specifications, position by position.

-- | Specification of a pair of models.
type instance ModelSpec (a, b) = (ModelSpec a, ModelSpec b)

-- | Specification of a triple of models.
type instance ModelSpec (a, b, c) = (ModelSpec a, ModelSpec b, ModelSpec c)

-- | Specification of a 4-tuple of models.
type instance ModelSpec (a, b, c, d) = (ModelSpec a, ModelSpec b, ModelSpec c, ModelSpec d)

-- | Specification of a 5-tuple of models.
type instance ModelSpec (a, b, c, d, e) = (ModelSpec a, ModelSpec b, ModelSpec c, ModelSpec d, ModelSpec e)

-- | Specification of a 6-tuple of models.
type instance ModelSpec (a, b, c, d, e, f) = (ModelSpec a, ModelSpec b, ModelSpec c, ModelSpec d, ModelSpec e, ModelSpec f)

-- | Specification of a 7-tuple of models.
type instance ModelSpec (a, b, c, d, e, f, g) = (ModelSpec a, ModelSpec b, ModelSpec c, ModelSpec d, ModelSpec e, ModelSpec f, ModelSpec g)

-- Tuple instances of 'HasInitialize' (arities 2 through 7).
--
-- None of these instances defines a method, so they use whatever default
-- 'initialize' the class declares (the class is outside this section;
-- presumably a Generic-based default -- confirm against the class).
--
-- In every instance the generator device is threaded left to right through
-- the constraints: the output generator device of one component is the input
-- generator device of the next, and the last component's output generator
-- device is the output generator device of the whole tuple.

-- | Pairs.
instance
  ( HasInitialize a generatorDevice outputA generatorOutputADevice,
    HasInitialize b generatorOutputADevice outputB generatorOutputDevice
  ) =>
  HasInitialize (a, b) generatorDevice (outputA, outputB) generatorOutputDevice

-- | Triples.
instance
  ( HasInitialize a generatorDevice outputA generatorOutputADevice,
    HasInitialize b generatorOutputADevice outputB generatorOutputBDevice,
    HasInitialize c generatorOutputBDevice outputC generatorOutputDevice
  ) =>
  HasInitialize (a, b, c) generatorDevice (outputA, outputB, outputC) generatorOutputDevice

-- | 4-tuples.
instance
  ( HasInitialize a generatorDevice outputA generatorOutputADevice,
    HasInitialize b generatorOutputADevice outputB generatorOutputBDevice,
    HasInitialize c generatorOutputBDevice outputC generatorOutputCDevice,
    HasInitialize d generatorOutputCDevice outputD generatorOutputDevice
  ) =>
  HasInitialize (a, b, c, d) generatorDevice (outputA, outputB, outputC, outputD) generatorOutputDevice

-- | 5-tuples.
instance
  ( HasInitialize a generatorDevice outputA generatorOutputADevice,
    HasInitialize b generatorOutputADevice outputB generatorOutputBDevice,
    HasInitialize c generatorOutputBDevice outputC generatorOutputCDevice,
    HasInitialize d generatorOutputCDevice outputD generatorOutputDDevice,
    HasInitialize e generatorOutputDDevice outputE generatorOutputDevice
  ) =>
  HasInitialize (a, b, c, d, e) generatorDevice (outputA, outputB, outputC, outputD, outputE) generatorOutputDevice

-- | 6-tuples.
instance
  ( HasInitialize a generatorDevice outputA generatorOutputADevice,
    HasInitialize b generatorOutputADevice outputB generatorOutputBDevice,
    HasInitialize c generatorOutputBDevice outputC generatorOutputCDevice,
    HasInitialize d generatorOutputCDevice outputD generatorOutputDDevice,
    HasInitialize e generatorOutputDDevice outputE generatorOutputEDevice,
    HasInitialize f generatorOutputEDevice outputF generatorOutputDevice
  ) =>
  HasInitialize (a, b, c, d, e, f) generatorDevice (outputA, outputB, outputC, outputD, outputE, outputF) generatorOutputDevice

-- | 7-tuples.
instance
  ( HasInitialize a generatorDevice outputA generatorOutputADevice,
    HasInitialize b generatorOutputADevice outputB generatorOutputBDevice,
    HasInitialize c generatorOutputBDevice outputC generatorOutputCDevice,
    HasInitialize d generatorOutputCDevice outputD generatorOutputDDevice,
    HasInitialize e generatorOutputDDevice outputE generatorOutputEDevice,
    HasInitialize f generatorOutputEDevice outputF generatorOutputFDevice,
    HasInitialize g generatorOutputFDevice outputG generatorOutputDevice
  ) =>
  HasInitialize (a, b, c, d, e, f, g) generatorDevice (outputA, outputB, outputC, outputD, outputE, outputF, outputG) generatorOutputDevice

-- 'ModelStack' instances of 'HasInitialize' for stacks of zero to seven
-- models.
--
-- No methods are defined here, so the class default for 'initialize' applies
-- (the class and 'ModelStack' are declared outside this section -- confirm
-- the default there). For two or more elements the instance context is the
-- 'HasInitialize' instance of the tuple with the same components.

-- | The empty stack: input and output generator devices are the same.
instance HasInitialize (ModelStack '[]) generatorDevice (ModelStack '[]) generatorDevice

-- | A one-element stack requires 'HasInitialize' for its single element.
instance
  HasInitialize a generatorDevice a' generatorOutputDevice =>
  HasInitialize (ModelStack '[a]) generatorDevice (ModelStack '[a']) generatorOutputDevice

-- | A two-element stack requires 'HasInitialize' for the matching pair.
instance
  HasInitialize (a, b) generatorDevice (a', b') generatorOutputDevice =>
  HasInitialize (ModelStack '[a, b]) generatorDevice (ModelStack '[a', b']) generatorOutputDevice

-- | A three-element stack requires 'HasInitialize' for the matching triple.
instance
  HasInitialize (a, b, c) generatorDevice (a', b', c') generatorOutputDevice =>
  HasInitialize (ModelStack '[a, b, c]) generatorDevice (ModelStack '[a', b', c']) generatorOutputDevice

-- | A four-element stack requires 'HasInitialize' for the matching 4-tuple.
instance
  HasInitialize (a, b, c, d) generatorDevice (a', b', c', d') generatorOutputDevice =>
  HasInitialize (ModelStack '[a, b, c, d]) generatorDevice (ModelStack '[a', b', c', d']) generatorOutputDevice

-- | A five-element stack requires 'HasInitialize' for the matching 5-tuple.
instance
  HasInitialize (a, b, c, d, e) generatorDevice (a', b', c', d', e') generatorOutputDevice =>
  HasInitialize (ModelStack '[a, b, c, d, e]) generatorDevice (ModelStack '[a', b', c', d', e']) generatorOutputDevice

-- | A six-element stack requires 'HasInitialize' for the matching 6-tuple.
instance
  HasInitialize (a, b, c, d, e, f) generatorDevice (a', b', c', d', e', f') generatorOutputDevice =>
  HasInitialize (ModelStack '[a, b, c, d, e, f]) generatorDevice (ModelStack '[a', b', c', d', e', f']) generatorOutputDevice

-- | A seven-element stack requires 'HasInitialize' for the matching 7-tuple.
instance
  HasInitialize (a, b, c, d, e, f, g) generatorDevice (a', b', c', d', e', f', g') generatorOutputDevice =>
  HasInitialize (ModelStack '[a, b, c, d, e, f, g]) generatorDevice (ModelStack '[a', b', c', d', e', f', g']) generatorOutputDevice

-- | Specification of a length-indexed vector of models: a singleton for the
-- length @n@ together with one 'ModelSpec' per element.
data VectorSpec (n :: Nat) (a :: Type) where
  VectorSpec ::
    forall n a.
    SNat n ->
    VS.Vector n (ModelSpec a) ->
    VectorSpec n a

-- | A 'VectorSpec' can be shown whenever its element specifications can.
deriving stock instance Show (ModelSpec a) => Show (VectorSpec n a)

-- | A sized vector of models is specified by a 'VectorSpec' of the same
-- length and element type.
type instance ModelSpec (VS.Vector n a) = VectorSpec n a

-- | Initialize a non-empty sized vector of models, one element per
-- specification, in order.
--
-- The first element is initialized with the incoming generator (device
-- @generatorDevice@). Every later element is initialized with the generator
-- returned by its predecessor, which is why a second constraint is needed for
-- generators that already live on @generatorOutputDevice@. The generator
-- returned by the last element is returned to the caller.
instance
  ( HasInitialize a generatorDevice output generatorOutputDevice,
    HasInitialize a generatorOutputDevice output generatorOutputDevice,
    n' ~ (n + 1)
  ) =>
  HasInitialize (VS.Vector n' a) generatorDevice (VS.Vector n' output) generatorOutputDevice
  where
  initialize (VectorSpec SNat (VGS.Vector specs)) g =
    case V.uncons specs of
      -- The length index is @n + 1@, so the underlying vector can only be
      -- empty if the sized-vector invariant has been violated elsewhere.
      Nothing ->
        error "HasInitialize (VS.Vector n' a): impossible, a vector of length n + 1 is empty"
      Just (spec, specs') -> do
        (a, g') <- initialize spec g
        (as, g'''') <-
          -- The accumulator is itself a monadic action; each step runs the
          -- action built so far, initializes one more element with the
          -- generator it produced, and appends the new element.
          V.foldl
            ( \agg spec' -> do
                (acc, g'') <- agg
                (a', g''') <- initialize spec' g''
                pure (V.snoc acc a', g''')
            )
            (pure (V.singleton a, g'))
            specs'
        pure (VGS.Vector as, g'''')

-- | Key under which an entry is stored in a 'StateDict'.
type StateDictKey = Text

-- | A state dictionary: a strict map from keys to foreign pointers to ATen
-- tensors.
type StateDict = Map.Map StateDictKey (ForeignPtr ATen.Tensor)

-- | Error raised while reading a model out of a state dictionary.
--
-- 'FromStateDictKeyNotFoundError' carries, in 'fsdeExpectedKey', the key that
-- was looked up but is absent from the dictionary.
newtype FromStateDictError = FromStateDictKeyNotFoundError {fsdeExpectedKey :: StateDictKey}
  deriving stock (Show, Typeable)

-- | 'FromStateDictError' can be thrown and caught as an exception. The
-- displayed message names the missing key, rendered with 'show' and wrapped
-- in backticks.
instance Exception FromStateDictError where
  displayException FromStateDictKeyNotFoundError {..} =
    "`" <> show fsdeExpectedKey <> "` is not in the model's state dictionary."

-- | Error raised while writing a model into a state dictionary.
--
-- 'ToStateDictKeyAlreadyInUseError' carries, in 'fsdeTakenKey', the key that
-- was about to be written but is already present in the dictionary. (The
-- field keeps its existing @fsde@ prefix so that callers using the record
-- selector are unaffected.)
newtype ToStateDictError = ToStateDictKeyAlreadyInUseError {fsdeTakenKey :: StateDictKey}
  deriving stock (Show, Typeable)

-- | 'ToStateDictError' can be thrown and caught as an exception. The
-- displayed message names the key that is already taken, rendered with 'show'
-- and wrapped in backticks.
instance Exception ToStateDictError where
  displayException ToStateDictKeyAlreadyInUseError {..} =
    "`" <> show fsdeTakenKey <> "` is already in the model's state dictionary."

type HasStateDict :: Type -> Constraint

-- | Models that can be read from and written to a 'StateDict' held in the
-- ambient 'MonadState'.
--
-- Both methods have Generic-based defaults that delegate to 'GHasStateDict'
-- on the generic representations of the model and of its specification, so
-- an instance for a 'Generic' record can usually be declared without a body.
class HasStateDict model where
  -- | Build a model from its specification, reading from the state
  -- dictionary in the monadic state. The 'StateDictKey' argument is the key
  -- (or key prefix) under which the model is looked up.
  fromStateDict ::
    forall m.
    (MonadIO m, MonadThrow m, MonadState StateDict m) =>
    ModelSpec model ->
    StateDictKey ->
    m model
  default fromStateDict ::
    forall m.
    ( MonadIO m,
      MonadThrow m,
      MonadState StateDict m,
      Generic model,
      Generic (ModelSpec model),
      GHasStateDict (Rep model) (Rep (ModelSpec model))
    ) =>
    ModelSpec model ->
    StateDictKey ->
    m model
  -- Default: convert the specification to its generic representation, load
  -- the generic model, and convert the result back.
  fromStateDict modelSpec k = to <$> gFromStateDict (from modelSpec) k

  -- | Write a model into the state dictionary in the monadic state. The
  -- 'StateDictKey' argument is the key (or key prefix) under which the model
  -- is stored.
  toStateDict ::
    forall m.
    (MonadThrow m, MonadState StateDict m) =>
    StateDictKey ->
    model ->
    m ()
  default toStateDict ::
    forall m.
    ( MonadThrow m,
      MonadState StateDict m,
      Generic model,
      GHasStateDict (Rep model) (Rep (ModelSpec model))
    ) =>
    StateDictKey ->
    model ->
    m ()
  -- Default: store the generic representation of the model.
  toStateDict k model = gToStateDict k (from model)

type GHasStateDict :: (Type -> Type) -> (Type -> Type) -> Constraint
-- | Generic counterpart of 'HasStateDict', relating the generic
-- representation of a model to the generic representation of its
-- specification. The functional dependencies state that each of the two
-- representations determines the other.
class GHasStateDict gModel gModelSpec | gModelSpec -> gModel, gModel -> gModelSpec where
  -- | Build the generic representation of a model from the generic
  -- representation of its specification and a state dictionary key.
  gFromStateDict ::
    forall m c.
    (MonadIO m, MonadThrow m, MonadState StateDict m) =>
    gModelSpec c ->
    StateDictKey ->
    m (gModel c)
  -- | Store the generic representation of a model under a state dictionary
  -- key.
  gToStateDict ::
    forall m c.
    (MonadThrow m, MonadState StateDict m) =>
    StateDictKey ->
    gModel c ->
    m ()

-- | Products: both factors are loaded from, and stored under, the same key
-- @k@; the left factor is processed before the right one.
instance
  (GHasStateDict gModelA gModelSpecA, GHasStateDict gModelB gModelSpecB) =>
  GHasStateDict
    (gModelA :*: gModelB)
    (gModelSpecA :*: gModelSpecB)
  where
  gFromStateDict (gModelASpec :*: gModelBSpec) k =
    (:*:) <$> gFromStateDict gModelASpec k <*> gFromStateDict gModelBSpec k
  gToStateDict k (gModelA :*: gModelB) = do
    () <- gToStateDict k gModelA
    () <- gToStateDict k gModelB
    pure ()

-- | Metadata wrappers are transparent: the 'M1' layer is stripped from the
-- specification (or model) and the wrapped representation is processed under
-- the unchanged key.
instance
  GHasStateDict gModel gModelSpec =>
  GHasStateDict (M1 i t gModel) (M1 i t gModelSpec)
  where
  gFromStateDict (M1 gModelSpec) k =
    M1 <$> gFromStateDict gModelSpec k
  gToStateDict k (M1 gModel) =
    gToStateDict k gModel

-- | Fields: a field of type @model@ is handled by its own 'HasStateDict'
-- instance, with the field of the specification record required to be
-- @ModelSpec model@. The key is passed through unchanged.
instance
  (HasStateDict model, modelSpec ~ ModelSpec model) =>
  GHasStateDict (K1 i model) (K1 i modelSpec)
  where
  gFromStateDict (K1 modelSpec) k =
    K1 <$> fromStateDict modelSpec k
  gToStateDict k (K1 model) =
    toStateDict k model

-- | Constructors without fields: nothing is read from or written to the
-- state dictionary, and the key is ignored.
instance GHasStateDict U1 U1 where
  gFromStateDict U1 _ = pure U1
  gToStateDict _ U1 = pure ()

-- | An 'SDim' has no entry in the state dictionary: loading returns the
-- specification itself and storing does nothing. The key is ignored in both
-- directions.
instance HasStateDict (SDim dim) where
  fromStateDict dim _ = pure dim
  toStateDict _ _ = pure ()

-- | The unit model has no entry in the state dictionary: loading yields @()@
-- and storing does nothing. The key is ignored in both directions.
instance HasStateDict () where
  fromStateDict () _ = pure ()
  toStateDict _ () = pure ()

-- | A named model extends the key: the wrapped model is loaded from, and
-- stored under, the incoming key with the model's name appended
-- (@key <> modelName@, with no separator inserted here). Loading re-attaches
-- the name taken from the specification.
instance HasStateDict model => HasStateDict (NamedModel model) where
  fromStateDict (NamedModel modelName modelSpec) key =
    NamedModel modelName <$> fromStateDict modelSpec (key <> modelName)
  toStateDict key (NamedModel modelName model) =
    toStateDict (key <> modelName) model

-- Tuple instances of 'HasStateDict' (arities 2 through 7).
--
-- None of these instances defines a method, so both 'fromStateDict' and
-- 'toStateDict' come from the Generic-based defaults of the class. Each
-- instance requires 'HasStateDict' for every component.

-- | Pairs.
instance
  ( HasStateDict a,
    HasStateDict b
  ) =>
  HasStateDict (a, b)

-- | Triples.
instance
  ( HasStateDict a,
    HasStateDict b,
    HasStateDict c
  ) =>
  HasStateDict (a, b, c)

-- | 4-tuples.
instance
  ( HasStateDict a,
    HasStateDict b,
    HasStateDict c,
    HasStateDict d
  ) =>
  HasStateDict (a, b, c, d)

-- | 5-tuples.
instance
  ( HasStateDict a,
    HasStateDict b,
    HasStateDict c,
    HasStateDict d,
    HasStateDict e
  ) =>
  HasStateDict (a, b, c, d, e)

-- | 6-tuples.
instance
  ( HasStateDict a,
    HasStateDict b,
    HasStateDict c,
    HasStateDict d,
    HasStateDict e,
    HasStateDict f
  ) =>
  HasStateDict (a, b, c, d, e, f)

-- | 7-tuples.
instance
  ( HasStateDict a,
    HasStateDict b,
    HasStateDict c,
    HasStateDict d,
    HasStateDict e,
    HasStateDict f,
    HasStateDict g
  ) =>
  HasStateDict (a, b, c, d, e, f, g)

-- 'ModelStack' instances of 'HasStateDict' for stacks of zero to seven
-- models.
--
-- No methods are defined, so the Generic-based class defaults are used
-- ('ModelStack' itself is declared outside this section). For two or more
-- elements the instance context is the 'HasStateDict' instance of the tuple
-- with the same components.

-- | The empty stack.
instance HasStateDict (ModelStack '[])

-- | A one-element stack requires 'HasStateDict' for its single element.
instance
  HasStateDict a =>
  HasStateDict (ModelStack '[a])

-- | A two-element stack requires 'HasStateDict' for the matching pair.
instance
  HasStateDict (a, b) =>
  HasStateDict (ModelStack '[a, b])

-- | A three-element stack requires 'HasStateDict' for the matching triple.
instance
  HasStateDict (a, b, c) =>
  HasStateDict (ModelStack '[a, b, c])

-- | A four-element stack requires 'HasStateDict' for the matching 4-tuple.
instance
  HasStateDict (a, b, c, d) =>
  HasStateDict (ModelStack '[a, b, c, d])

-- | A five-element stack requires 'HasStateDict' for the matching 5-tuple.
instance
  HasStateDict (a, b, c, d, e) =>
  HasStateDict (ModelStack '[a, b, c, d, e])

-- | A six-element stack requires 'HasStateDict' for the matching 6-tuple.
instance
  HasStateDict (a, b, c, d, e, f) =>
  HasStateDict (ModelStack '[a, b, c, d, e, f])

-- | A seven-element stack requires 'HasStateDict' for the matching 7-tuple.
instance
  HasStateDict (a, b, c, d, e, f, g) =>
  HasStateDict (ModelStack '[a, b, c, d, e, f, g])

-- | A tensor is specified by a 'TensorSpec' with the same gradient, layout,
-- device, data type, and shape indices.
type instance ModelSpec (Tensor gradient layout device dataType shape) = TensorSpec gradient layout device dataType shape

-- | Tensors are the entries of a state dictionary: each tensor is stored
-- under exactly the key it is given.
--
-- * 'fromStateDict' throws 'FromStateDictKeyNotFoundError' if the key is
--   absent. Otherwise the stored pointer is wrapped as an 'UncheckedTensor'
--   and then passed, in this order, through 'sSetGradient', 'sCheckedLayout',
--   'sSetDevice', 'sCheckedDataType', and 'sCheckedShape' with the values
--   from the 'TensorSpec'; any of the steps running in @m@ may fail there.
--
-- * 'toStateDict' throws 'ToStateDictKeyAlreadyInUseError' if the key is
--   already present; otherwise it inserts the tensor's pointer and replaces
--   the state with the extended dictionary.
instance
  HasStateDict
    (Tensor gradient layout device dataType shape)
  where
  fromStateDict (TensorSpec gradient layout device dataType shape) k = do
    stateDict <- get
    maybe
      (throwM . FromStateDictKeyNotFoundError $ k)
      -- The annotation pins the freshly wrapped tensor to the unchecked
      -- type; the steps below refine it to the specified type.
      (\t -> pure (UnsafeTensor t :: UncheckedTensor))
      (Map.lookup k stateDict)
      >>= liftIO . sSetGradient gradient
      >>= sCheckedLayout layout
      >>= sSetDevice device
      >>= sCheckedDataType dataType
      >>= sCheckedShape shape
  toStateDict k (UnsafeTensor t) = do
    stateDict <- get
    stateDict' <-
      maybe
        (pure $ Map.insert k t stateDict)
        -- Never overwrite an existing entry.
        (\_ -> throwM . ToStateDictKeyAlreadyInUseError $ k)
        (Map.lookup k stateDict)
    put stateDict'

-- | Sized vectors store each element under its own key. Element number @i@
-- (counting from 0) uses the incoming key followed by the decimal index and a
-- dot, i.e. @k <> Text.pack (show i) <> "."@. Loading and storing use the
-- same key scheme.
instance
  HasStateDict a =>
  HasStateDict (VS.Vector n a)
  where
  fromStateDict (VectorSpec SNat specs) k = do
    -- Matching on 'SNat' makes the length @n@ available at the value level.
    let i :: Int = fromIntegral (natVal (Proxy :: Proxy n))
        fromStateDict' (spec, i') = fromStateDict spec (k <> Text.pack (show i') <> ".")
    -- Pair every specification with its index and load the elements in order.
    traverse fromStateDict' $ VS.zip specs (VGS.Vector $ V.fromList [0 .. i - 1])
  toStateDict k (VGS.Vector v) = do
    let toStateDict' (i', a) = toStateDict (k <> Text.pack (show i') <> ".") a
    -- Pair every element with its index and store the elements in order.
    mapM_ toStateDict' $ V.zip (V.fromList [0 .. V.length v - 1]) v

-- | Load a state dictionary from a TorchScript file.
--
-- The file is read with 'Torch.Serialize.pickleLoad'. The loaded value must
-- be a generic dictionary whose entries all have a string key and a tensor
-- value; the entries are collected into a 'StateDict'. Any other shape makes
-- the action 'fail' with a message describing the first offending part.
stateDictFromFile ::
  FilePath ->
  IO StateDict
stateDictFromFile filePath = do
  iValue <- Torch.Serialize.pickleLoad filePath
  case iValue of
    Torch.Script.IVGenericDict xs -> Map.fromList <$> go xs
    _ -> fail "iValue is not a tensor dictionary."
  where
    -- Convert the dictionary entries one by one, stopping at the first entry
    -- that is not a (string, tensor) pair.
    go [] = pure []
    go ((Torch.Script.IVString s, Torch.Script.IVTensor (Torch.Tensor.Unsafe t)) : xs) = ((Text.pack s, t) :) <$> go xs
    -- Tensor value, but the key is not a string.
    go ((_, Torch.Script.IVTensor _) : _) = fail "iValue is not a string."
    -- String key, but the value is not a tensor.
    go ((Torch.Script.IVString _, _) : _) = fail "iValue is not a tensor."
    go _ = fail "iValue is neither a string nor a tensor."

-- | Save a state dictionary to a TorchScript file.
--
-- Every entry of the dictionary becomes a (string, tensor) pair of a generic
-- dictionary value, in the order given by 'Map.toList', and that value is
-- written to the given path with 'Torch.Serialize.pickleSave'.
stateDictToFile ::
  StateDict ->
  FilePath ->
  IO ()
stateDictToFile stateDict filePath = do
  let iValue =
        Torch.Script.IVGenericDict $
          bimap
            (Torch.Script.IVString . Text.unpack)
            (Torch.Script.IVTensor . Torch.Tensor.Unsafe)
            <$> Map.toList stateDict
  Torch.Serialize.pickleSave iValue filePath