{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Sparse where
import Control.Monad.Indexed (IxPointed (ireturn), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Data.Data (Proxy (..))
import GHC.Generics (Generic)
import GHC.TypeLits (KnownNat, Nat, Symbol, natVal)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType, SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.NN.Functional.Sparse (EmbeddingF, embedding)
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.Prelude.Maybe (SMaybe (..))
import Torch.GraduallyTyped.Random (SGetGeneratorDevice)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape (Dim (..), Name, SDim (..), SShape (..), Shape (..), Size)
import Torch.GraduallyTyped.Tensor.Creation (sRandn)
import Torch.GraduallyTyped.Tensor.Type (SGetLayout, Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
-- | Embedding layer: a single learnable weight tensor of shape
-- @'Shape '[embedNumDim, embedDim]@.
--
-- The @paddingIdx@ type parameter is an optional, type-level padding index.
-- It has no term-level field; where it is needed (the forward pass) it is
-- reflected from the type with 'natVal'.
data
  Embedding
    (gradient :: Gradient RequiresGradient)
    (layout :: Layout LayoutType)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (embedNumDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (paddingIdx :: Maybe Nat)
  where
  Embedding ::
    forall gradient layout device dataType embedNumDim embedDim paddingIdx.
    { -- The lookup table passed to 'embedding'; its first dimension is
      -- @embedNumDim@ and its second is @embedDim@.
      embeddingWeight :: Tensor gradient layout device dataType ('Shape '[embedNumDim, embedDim])
    } ->
    Embedding gradient layout device dataType embedNumDim embedDim paddingIdx
  deriving stock (Show, Generic)
-- | Term-level specification of an 'Embedding'.
--
-- Carries the singletons needed to build the weight tensor (gradient mode,
-- layout, device, data type and the two dimensions), plus the optional
-- padding index as an 'SMaybe' singleton.
data
  EmbeddingSpec
    (gradient :: Gradient RequiresGradient)
    (layout :: Layout LayoutType)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (embedNumDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (paddingIdx :: Maybe Nat)
  where
  EmbeddingSpec ::
    forall gradient layout device dataType embedNumDim embedDim paddingIdx.
    SGradient gradient ->
    SLayout layout ->
    SDevice device ->
    SDataType dataType ->
    -- first dimension of the weight
    SDim embedNumDim ->
    -- second dimension of the weight
    SDim embedDim ->
    -- optional padding index
    SMaybe paddingIdx ->
    EmbeddingSpec gradient layout device dataType embedNumDim embedDim paddingIdx
  deriving stock (Show, Generic)
-- An 'Embedding' is specified by an 'EmbeddingSpec' indexed by exactly the
-- same type parameters.
type instance
  ModelSpec (Embedding gradient layout device dataType embedNumDim embedDim paddingIdx) =
    EmbeddingSpec gradient layout device dataType embedNumDim embedDim paddingIdx
-- | Initialize an 'Embedding' by drawing its weight with 'sRandn' from the
-- supplied generator.
--
-- The weight (and hence the model) ends up on @device <+> generatorDevice@,
-- and the advanced generator is returned alongside the model.
instance
  ( output ~ Embedding gradient layout (device <+> generatorDevice) dataType embedNumDim embedDim paddingIdx,
    generatorOutputDevice ~ (device <+> generatorDevice),
    SGetGeneratorDevice generatorDevice
  ) =>
  HasInitialize
    (Embedding gradient layout device dataType embedNumDim embedDim paddingIdx)
    generatorDevice
    output
    generatorOutputDevice
  where
  -- Initialization does not depend on the padding index, so a single
  -- equation covers both the 'SNothing' and the 'SJust' case.
  -- NOTE(review): PyTorch's nn.Embedding zeros the @padding_idx@ row of a
  -- freshly constructed layer; this initializer does not. Confirm intent.
  initialize (EmbeddingSpec gradient layout device dataType embedNumDim embedDim _paddingIdx) =
    runIxStateT $
      IxStateT (sRandn weightSpec)
        >>>= ireturn . Embedding
    where
      -- Shape @[embedNumDim, embedDim]@, matching 'embeddingWeight'.
      weightSpec =
        TensorSpec gradient layout device dataType (SShape $ embedNumDim :|: embedDim :|: SNil)
-- | State-dict (de)serialization of an 'Embedding'.
--
-- The single weight tensor is stored under the key @k <> "weight"@, where
-- @k@ is the prefix supplied by the caller.
instance
  HasStateDict
    (Embedding gradient layout device dataType embedNumDim embedDim paddingIdx)
  where
  -- The padding index plays no role in (de)serialization.
  fromStateDict (EmbeddingSpec gradient layout device dataType embedNumDim embedDim _paddingIdx) k =
    Embedding
      <$> fromStateDict
        (TensorSpec gradient layout device dataType (SShape $ embedNumDim :|: embedDim :|: SNil))
        (k <> "weight")
  toStateDict k Embedding {..} =
    toStateDict (k <> "weight") embeddingWeight
-- | Forward pass of an 'Embedding' without a padding index.
--
-- Applies 'embedding' to the layer's weight and the @input@ tensor. The
-- input's data type must unify with @'DataType 'Int64@. The generator is not
-- consumed and is returned unchanged.
instance
  ( SGetLayout layout,
    Catch (dataType' <+> 'DataType 'Int64),
    output
      ~ Tensor
          (gradient <|> gradient')
          (layout <+> layout')
          (device <+> device')
          dataType
          (EmbeddingF ('Shape '[embedNumDim, embedDim]) shape')
  ) =>
  HasForward
    (Embedding gradient layout device dataType embedNumDim embedDim 'Nothing)
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  -- NOTE(review): the 'False' flag is forwarded verbatim to 'embedding';
  -- presumably it corresponds to PyTorch's @scale_grad_by_freq@ — confirm.
  forward Embedding {..} input g =
    pure (embedding Nothing False embeddingWeight input, g)
-- | Forward pass of an 'Embedding' with a type-level padding index.
--
-- The padding index is reflected from the type with 'natVal' and passed to
-- 'embedding' together with the layer's weight and the @input@ tensor. The
-- input's data type must unify with @'DataType 'Int64@. The generator is not
-- consumed and is returned unchanged.
instance
  ( SGetLayout layout,
    KnownNat paddingIdx,
    Catch (dataType' <+> 'DataType 'Int64),
    output
      ~ Tensor
          (gradient <|> gradient')
          (layout <+> layout')
          (device <+> device')
          dataType
          (EmbeddingF ('Shape '[embedNumDim, embedDim]) shape')
  ) =>
  HasForward
    (Embedding gradient layout device dataType embedNumDim embedDim ('Just paddingIdx))
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  -- NOTE(review): the 'False' flag is forwarded verbatim to 'embedding';
  -- presumably it corresponds to PyTorch's @scale_grad_by_freq@ — confirm.
  forward Embedding {..} input g =
    pure
      ( embedding
          (Just . fromIntegral . natVal $ Proxy @paddingIdx)
          False
          embeddingWeight
          input,
        g
      )