{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Sparse where

import Control.Monad.Indexed (IxPointed (ireturn), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Data.Data (Proxy (..))
import GHC.Generics (Generic)
import GHC.TypeLits (KnownNat, Nat, Symbol, natVal)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType, SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.NN.Functional.Sparse (EmbeddingF, embedding)
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.Prelude.Maybe (SMaybe (..))
import Torch.GraduallyTyped.Random (SGetGeneratorDevice)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape (Dim (..), Name, SDim (..), SShape (..), Shape (..), Size)
import Torch.GraduallyTyped.Tensor.Creation (sRandn)
import Torch.GraduallyTyped.Tensor.Type (SGetLayout, Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

data
  Embedding
    (gradient :: Gradient RequiresGradient)
    (layout :: Layout LayoutType)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (embedNumDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (paddingIdx :: Maybe Nat)
  where
  Embedding ::
    forall gradient layout device dataType embedNumDim embedDim paddingIdx.
    { forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Embedding
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> Tensor
     gradient layout device dataType ('Shape '[embedNumDim, embedDim])
embeddingWeight :: Tensor gradient layout device dataType ('Shape '[embedNumDim, embedDim])
    } ->
    Embedding gradient layout device dataType embedNumDim embedDim paddingIdx
  deriving stock (Int
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Int
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
-> ShowS
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
[Embedding
   gradient layout device dataType embedNumDim embedDim paddingIdx]
-> ShowS
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Embedding
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> String
showList :: [Embedding
   gradient layout device dataType embedNumDim embedDim paddingIdx]
-> ShowS
$cshowList :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
[Embedding
   gradient layout device dataType embedNumDim embedDim paddingIdx]
-> ShowS
show :: Embedding
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> String
$cshow :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Embedding
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> String
showsPrec :: Int
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
-> ShowS
$cshowsPrec :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Int
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat) x.
Rep
  (Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx)
  x
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat) x.
Embedding
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> Rep
     (Embedding
        gradient layout device dataType embedNumDim embedDim paddingIdx)
     x
$cto :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat) x.
Rep
  (Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx)
  x
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
$cfrom :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat) x.
Embedding
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> Rep
     (Embedding
        gradient layout device dataType embedNumDim embedDim paddingIdx)
     x
Generic)

data
  EmbeddingSpec
    (gradient :: Gradient RequiresGradient)
    (layout :: Layout LayoutType)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (embedNumDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (paddingIdx :: Maybe Nat)
  where
  EmbeddingSpec ::
    forall gradient layout device dataType embedNumDim embedDim paddingIdx.
    SGradient gradient ->
    SLayout layout ->
    SDevice device ->
    SDataType dataType ->
    SDim embedNumDim ->
    SDim embedDim ->
    SMaybe paddingIdx ->
    EmbeddingSpec gradient layout device dataType embedNumDim embedDim paddingIdx
  deriving stock (Int
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Int
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
-> ShowS
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
[EmbeddingSpec
   gradient layout device dataType embedNumDim embedDim paddingIdx]
-> ShowS
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
EmbeddingSpec
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> String
showList :: [EmbeddingSpec
   gradient layout device dataType embedNumDim embedDim paddingIdx]
-> ShowS
$cshowList :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
[EmbeddingSpec
   gradient layout device dataType embedNumDim embedDim paddingIdx]
-> ShowS
show :: EmbeddingSpec
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> String
$cshow :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
EmbeddingSpec
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> String
showsPrec :: Int
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
-> ShowS
$cshowsPrec :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Int
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat) x.
Rep
  (EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx)
  x
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat) x.
EmbeddingSpec
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> Rep
     (EmbeddingSpec
        gradient layout device dataType embedNumDim embedDim paddingIdx)
     x
$cto :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat) x.
Rep
  (EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx)
  x
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
$cfrom :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat) x.
EmbeddingSpec
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> Rep
     (EmbeddingSpec
        gradient layout device dataType embedNumDim embedDim paddingIdx)
     x
Generic)

type instance ModelSpec (Embedding gradient layout device dataType embedNumDim embedDim paddingIdx) = EmbeddingSpec gradient layout device dataType embedNumDim embedDim paddingIdx

instance
  ( output ~ Embedding gradient layout (device <+> generatorDevice) dataType embedNumDim embedDim paddingIdx,
    generatorOutputDevice ~ (device <+> generatorDevice),
    SGetGeneratorDevice generatorDevice
  ) =>
  HasInitialize
    (Embedding gradient layout device dataType embedNumDim embedDim paddingIdx)
    generatorDevice
    output
    generatorOutputDevice
  where
  initialize :: forall (m :: * -> *).
MonadThrow m =>
ModelSpec
  (Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx)
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
initialize (EmbeddingSpec SGradient gradient
gradient SLayout layout
layout SDevice device
device SDataType dataType
dataType SDim embedNumDim
embedNumDim SDim embedDim
embedDim SMaybe paddingIdx
SNothing) =
    forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
      forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT (forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (generatorDevice :: Device (DeviceType Nat)) (m :: * -> *).
(SGetGeneratorDevice generatorDevice, MonadThrow m) =>
TensorSpec gradient layout device dataType shape
-> Generator generatorDevice
-> m (Tensor
        gradient layout (device <+> generatorDevice) dataType shape,
      Generator (device <+> generatorDevice))
sRandn forall a b. (a -> b) -> a -> b
$ forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)]).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SShape shape
-> TensorSpec gradient layout device dataType shape
TensorSpec SGradient gradient
gradient SLayout layout
layout SDevice device
device SDataType dataType
dataType (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim embedNumDim
embedNumDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim embedDim
embedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil))
        forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
Embedding
  initialize (EmbeddingSpec SGradient gradient
gradient SLayout layout
layout SDevice device
device SDataType dataType
dataType SDim embedNumDim
embedNumDim SDim embedDim
embedDim (SJust Sing n
_)) =
    -- TODO: padding embedding vector may need to be set to zeros
    forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
      forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT (forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (generatorDevice :: Device (DeviceType Nat)) (m :: * -> *).
(SGetGeneratorDevice generatorDevice, MonadThrow m) =>
TensorSpec gradient layout device dataType shape
-> Generator generatorDevice
-> m (Tensor
        gradient layout (device <+> generatorDevice) dataType shape,
      Generator (device <+> generatorDevice))
sRandn forall a b. (a -> b) -> a -> b
$ forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)]).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SShape shape
-> TensorSpec gradient layout device dataType shape
TensorSpec SGradient gradient
gradient SLayout layout
layout SDevice device
device SDataType dataType
dataType (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim embedNumDim
embedNumDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim embedDim
embedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil))
        forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
Embedding

instance
  HasStateDict
    (Embedding gradient layout device dataType embedNumDim embedDim paddingIdx)
  where
  fromStateDict :: forall (m :: * -> *).
(MonadIO m, MonadThrow m, MonadState StateDict m) =>
ModelSpec
  (Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx)
-> StateDictKey
-> m (Embedding
        gradient layout device dataType embedNumDim embedDim paddingIdx)
fromStateDict (EmbeddingSpec SGradient gradient
gradient SLayout layout
layout SDevice device
device SDataType dataType
dataType SDim embedNumDim
embedNumDim SDim embedDim
embedDim SMaybe paddingIdx
_paddingIdx) StateDictKey
k =
    forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
Embedding forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall model (m :: * -> *).
(HasStateDict model, MonadIO m, MonadThrow m,
 MonadState StateDict m) =>
ModelSpec model -> StateDictKey -> m model
fromStateDict (forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)]).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SShape shape
-> TensorSpec gradient layout device dataType shape
TensorSpec SGradient gradient
gradient SLayout layout
layout SDevice device
device SDataType dataType
dataType (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim embedNumDim
embedNumDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim embedDim
embedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil)) (StateDictKey
k forall a. Semigroup a => a -> a -> a
<> StateDictKey
"weight")
  toStateDict :: forall (m :: * -> *).
(MonadThrow m, MonadState StateDict m) =>
StateDictKey
-> Embedding
     gradient layout device dataType embedNumDim embedDim paddingIdx
-> m ()
toStateDict StateDictKey
k Embedding {Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
embeddingWeight :: Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
embeddingWeight :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Embedding
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> Tensor
     gradient layout device dataType ('Shape '[embedNumDim, embedDim])
..} =
    forall model (m :: * -> *).
(HasStateDict model, MonadThrow m, MonadState StateDict m) =>
StateDictKey -> model -> m ()
toStateDict (StateDictKey
k forall a. Semigroup a => a -> a -> a
<> StateDictKey
"weight") Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
embeddingWeight

instance
  ( SGetLayout layout,
    Catch (dataType' <+> 'DataType 'Int64),
    output
      ~ Tensor
          (gradient <|> gradient')
          (layout <+> layout')
          (device <+> device')
          dataType
          (EmbeddingF ('Shape '[embedNumDim, embedDim]) shape')
  ) =>
  HasForward
    (Embedding gradient layout device dataType embedNumDim embedDim 'Nothing)
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  forward :: forall (m :: * -> *).
MonadThrow m =>
Embedding
  gradient layout device dataType embedNumDim embedDim 'Nothing
-> Tensor gradient' layout' device' dataType' shape'
-> Generator generatorDevice
-> m (output, Generator generatorDevice)
forward (Embedding Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
weight) Tensor gradient' layout' device' dataType' shape'
input = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient' :: Gradient RequiresGradient)
       (layout' :: Layout LayoutType) (device' :: Device (DeviceType Nat))
       (dataType' :: DataType DType)
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)]).
(SGetLayout layout, Catch (dataType' <+> 'DataType 'Int64)) =>
Maybe Nat
-> Bool
-> Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> Tensor
     (gradient <|> gradient')
     (layout <+> layout')
     (device <+> device')
     dataType
     (EmbeddingF shape shape')
embedding forall a. Maybe a
Nothing Bool
False Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
weight Tensor gradient' layout' device' dataType' shape'
input,)

instance
  ( SGetLayout layout,
    KnownNat paddingIdx,
    Catch (dataType' <+> 'DataType 'Int64),
    output
      ~ Tensor
          (gradient <|> gradient')
          (layout <+> layout')
          (device <+> device')
          dataType
          (EmbeddingF ('Shape '[embedNumDim, embedDim]) shape')
  ) =>
  HasForward
    (Embedding gradient layout device dataType embedNumDim embedDim ('Just paddingIdx))
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  forward :: forall (m :: * -> *).
MonadThrow m =>
Embedding
  gradient
  layout
  device
  dataType
  embedNumDim
  embedDim
  ('Just paddingIdx)
-> Tensor gradient' layout' device' dataType' shape'
-> Generator generatorDevice
-> m (output, Generator generatorDevice)
forward Embedding {Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
embeddingWeight :: Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
embeddingWeight :: forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
Embedding
  gradient layout device dataType embedNumDim embedDim paddingIdx
-> Tensor
     gradient layout device dataType ('Shape '[embedNumDim, embedDim])
..} Tensor gradient' layout' device' dataType' shape'
input = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient' :: Gradient RequiresGradient)
       (layout' :: Layout LayoutType) (device' :: Device (DeviceType Nat))
       (dataType' :: DataType DType)
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)]).
(SGetLayout layout, Catch (dataType' <+> 'DataType 'Int64)) =>
Maybe Nat
-> Bool
-> Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> Tensor
     (gradient <|> gradient')
     (layout <+> layout')
     (device <+> device')
     dataType
     (EmbeddingF shape shape')
embedding (forall a. a -> Maybe a
Just forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (Integral a, Num b) => a -> b
fromIntegral forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (n :: Nat) (proxy :: Nat -> *).
KnownNat n =>
proxy n -> Integer
natVal forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). Proxy t
Proxy @paddingIdx) Bool
False Tensor
  gradient layout device dataType ('Shape '[embedNumDim, embedDim])
embeddingWeight Tensor gradient' layout' device' dataType' shape'
input,)