hasktorch-gradually-typed-0.2.0.0: experimental project for hasktorch
Safe HaskellSafe-Inferred
LanguageHaskell2010

Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

Synopsis

Documentation

data MultiHeadAttentionHasScaling Source #

Data type for representing whether or not (and, if so, where) scaling is applied in the multi-headed attention layer.

Constructors

MultiHeadAttentionWithoutScaling

Scaling is not done.

MultiHeadAttentionWithQueryScaling

Scaling is applied to the query after the in-projection.

MultiHeadAttentionWithWeightScaling

Scaling is applied to the attention weights.

Instances

Instances details
Generic MultiHeadAttentionHasScaling Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

Associated Types

type Rep MultiHeadAttentionHasScaling :: Type -> Type Source #

Show MultiHeadAttentionHasScaling Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

Eq MultiHeadAttentionHasScaling Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

Ord MultiHeadAttentionHasScaling Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

HasStateDict MultiHeadAttentionHasScaling Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

HasInitialize MultiHeadAttentionHasScaling generatorDevice MultiHeadAttentionHasScaling generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

type Rep MultiHeadAttentionHasScaling Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

type Rep MultiHeadAttentionHasScaling = D1 ('MetaData "MultiHeadAttentionHasScaling" "Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "MultiHeadAttentionWithoutScaling" 'PrefixI 'False) (U1 :: Type -> Type) :+: (C1 ('MetaCons "MultiHeadAttentionWithQueryScaling" 'PrefixI 'False) (U1 :: Type -> Type) :+: C1 ('MetaCons "MultiHeadAttentionWithWeightScaling" 'PrefixI 'False) (U1 :: Type -> Type)))
type ModelSpec MultiHeadAttentionHasScaling Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

data GMultiHeadAttention (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (qInProj :: Type) (kInProj :: Type) (vInProj :: Type) (outProj :: Type) (dropout :: Type) where Source #

Generic multi-headed attention layer.

  • headDim is the dimension of the attention heads.
  • headEmbedDim is the dimension of the attention head embedding.
  • embedDim is the dimension of the embedding.
  • qInProj is the type of the query projection.
  • kInProj is the type of the key projection.
  • vInProj is the type of the value projection.
  • outProj is the type of the output projection.
  • dropout is the type of the dropout layer.

Constructors

GMultiHeadAttention 

Fields

Instances

Instances details
Generic (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

Associated Types

type Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) :: Type -> Type Source #

Methods

from :: GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) x Source #

to :: Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) x -> GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout Source #

(Show qInProj, Show kInProj, Show vInProj, Show outProj, Show dropout) => Show (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

Methods

showsPrec :: Int -> GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> ShowS Source #

show :: GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> String Source #

showList :: [GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout] -> ShowS Source #

(HasStateDict qInProj, HasStateDict vInProj, HasStateDict kInProj, HasStateDict outProj, HasStateDict dropout) => HasStateDict (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

Methods

fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) -> StateDictKey -> m (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source #

toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> m () Source #

(HasInitialize qInProj generatorDevice qInProj' generatorDevice0, HasInitialize kInProj generatorDevice0 kInProj' generatorDevice1, HasInitialize vInProj generatorDevice1 vInProj' generatorDevice2, HasInitialize outProj generatorDevice2 outProj' generatorDevice3, HasInitialize dropout generatorDevice3 dropout' generatorOutputDevice) => HasInitialize (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) generatorDevice (GMultiHeadAttention headDim headEmbedDim embedDim qInProj' kInProj' vInProj' outProj' dropout') generatorOutputDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

Methods

initialize :: MonadThrow m => ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) -> Generator generatorDevice -> m (GMultiHeadAttention headDim headEmbedDim embedDim qInProj' kInProj' vInProj' outProj' dropout', Generator generatorOutputDevice) Source #

(HasForward qInProj (Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape) generatorDevice (Tensor qRequiresGradient qLayout qDevice qDataType qShape0) qGeneratorOutputDevice, reshapedQShape0 ~ ReshapeF qShape0 ('Shape '[batchDim, querySeqDim, headDim, headEmbedDim]), Catch reshapedQShape0, qShape ~ TransposeF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) ('SelectDim ('ByIndex 2 :: By Symbol Natural)) reshapedQShape0, Catch qShape, HasForward kInProj (Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape) qGeneratorOutputDevice (Tensor qRequiresGradient kLayout kDevice kDataType kShape0) kGeneratorOutputDevice, reshapedKShape0 ~ ReshapeF kShape0 ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]), Catch reshapedKShape0, transposedReshapedKShape0 ~ TransposeF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) ('SelectDim ('ByIndex 2 :: By Symbol Natural)) reshapedKShape0, Catch transposedReshapedKShape0, doubleTransposedReshapedKShape0 ~ TransposeF ('SelectDim ('ByIndex 2 :: By Symbol Natural)) ('SelectDim ('ByIndex 3 :: By Symbol Natural)) transposedReshapedKShape0, Catch doubleTransposedReshapedKShape0, multipliedQDoubleTransposedReshapedKShape0 ~ MatmulF qShape doubleTransposedReshapedKShape0, Catch multipliedQDoubleTransposedReshapedKShape0, weightsShape0 ~ SoftmaxF ('SelectDim ('ByIndex 3 :: By Symbol Natural)) (BroadcastShapesF multipliedQDoubleTransposedReshapedKShape0 attentionBiasShape), Catch (BroadcastShapesF multipliedQDoubleTransposedReshapedKShape0 attentionBiasShape), Catch weightsShape0, HasForward dropout (Tensor (qRequiresGradient <|> attentionBiasRequiresGradient) (qLayout <+> (kLayout <+> attentionBiasLayout)) (qDevice <+> (kDevice <+> attentionBiasDevice)) (qDataType <+> (kDataType <+> attentionBiasDataType)) weightsShape0) kGeneratorOutputDevice (Tensor weightsRequiresGradient weightsLayout weightsDevice weightsDataType weightsShape) weightsGeneratorOutputDevice, HasForward vInProj (Tensor 
valueRequiresGradient valueLayout valueDevice valueDataType valueShape) weightsGeneratorOutputDevice (Tensor weightsRequiresGradient vLayout vDevice vDataType vShape0) vGeneratorOutputDevice, reshapedVShape0 ~ ReshapeF vShape0 ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]), Catch reshapedVShape0, transposedReshapedVShape ~ TransposeF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) ('SelectDim ('ByIndex 2 :: By Symbol Natural)) reshapedVShape0, Catch transposedReshapedVShape, multipliedWeightsTransposedReshapedVShape ~ MatmulF weightsShape transposedReshapedVShape, Catch multipliedWeightsTransposedReshapedVShape, outputQueryShape0 ~ TransposeF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) ('SelectDim ('ByIndex 2 :: By Symbol Natural)) multipliedWeightsTransposedReshapedVShape, Catch outputQueryShape0, HasForward outProj (Tensor weightsRequiresGradient (weightsLayout <+> vLayout) (weightsDevice <+> vDevice) (weightsDataType <+> vDataType) reshapedOutputQueryShape0) vGeneratorOutputDevice output generatorOutputDevice, reshapedOutputQueryShape0 ~ ReshapeF outputQueryShape0 ('Shape '[batchDim, querySeqDim, embedDim]), Catch reshapedOutputQueryShape0, SGetShape queryShape, SGetShape keyShape, SGetShape valueShape, batchDim ~ BatchDim queryShape keyShape valueShape, querySeqDim ~ QuerySeqDim queryShape, keySeqDim ~ KeySeqDim keyShape valueShape) => HasForward (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) (Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape, Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape, Tensor valueRequiresGradient valueLayout valueDevice valueDataType valueShape, Tensor attentionBiasRequiresGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape) generatorDevice output generatorOutputDevice Source #

HasForward instance for GMultiHeadAttention.

┌───────────────┐        ┌───────┐       ┌─────┐       ┌───────┐
│ attentionBias │        │ query │       │ key │       │ value │
└───────┬───────┘        └───┬───┘       └──┬──┘       └───┬───┘
        │                    │              │              │
        │                    ▼              ▼              ▼
        │                mhaQInProj     mhaKInProj     mhaVInProj
        │                    ▼              │              │
        │                (scaling)          │              │
        │                    ▼              ▼              ▼
        │                 reshape        reshape        reshape
        │                    ▼              ▼              ▼
        │                transpose      transpose      transpose
        │                    │              ▼              │
        │                    │          transpose          │
        │                    │              │              │
        │                    └───►matmul◄───┘              │
        │                           ▼                      │
        │                       (scaling)                  │
        │                           │                      │
        └──────────►add◄────────────┘                      │
                     ▼                                     │
                  softmax                                  │
                     ▼                                     │
                 mhaDropout                                │
                     │                                     │
                     └──────────────►matmul◄───────────────┘
                                       ▼
                                   transpose
                                       ▼
                                    reshape
                                       ▼
                                   mhaOutProj
                                       │
                                       ▼
                                   ┌───────┐
                                   │ query │
                                   └───────┘
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

Methods

forward :: MonadThrow m => GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> (Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape, Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape, Tensor valueRequiresGradient valueLayout valueDevice valueDataType valueShape, Tensor attentionBiasRequiresGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape) -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #

type Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

type Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) = D1 ('MetaData "GMultiHeadAttention" "Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "GMultiHeadAttention" 'PrefixI 'True) (((S1 ('MetaSel ('Just "mhaHeadDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim headDim)) :*: S1 ('MetaSel ('Just "mhaHeadEmbedDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim headEmbedDim))) :*: (S1 ('MetaSel ('Just "mhaEmbedDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim embedDim)) :*: S1 ('MetaSel ('Just "mhaQInProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 qInProj))) :*: ((S1 ('MetaSel ('Just "mhaKInProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 kInProj) :*: S1 ('MetaSel ('Just "mhaVInProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 vInProj)) :*: (S1 ('MetaSel ('Just "mhaOutProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 outProj) :*: (S1 ('MetaSel ('Just "mhaDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 dropout) :*: S1 ('MetaSel ('Just "mhaScaling") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 MultiHeadAttentionHasScaling))))))
type ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention

type ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) = GMultiHeadAttention headDim headEmbedDim embedDim (ModelSpec qInProj) (ModelSpec kInProj) (ModelSpec vInProj) (ModelSpec outProj) (ModelSpec dropout)

type family GMultiHeadAttentionF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (valueEmbedDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) :: Type where ... Source #

Equations

GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout = GMultiHeadAttention headDim headEmbedDim embedDim (QInProjF style gradient device dataType queryEmbedDim embedDim) (KInProjF style gradient device dataType keyEmbedDim embedDim) (VInProjF style gradient device dataType valueEmbedDim embedDim) (OutProjF style gradient device dataType embedDim queryEmbedDim) (DropoutF style hasDropout) 

type family QInProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

Specifies the linear transformation of the query.

Equations

QInProjF 'T5 gradient device dataType queryEmbedDim embedDim = NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim embedDim) 
QInProjF 'ByT5 gradient device dataType queryEmbedDim embedDim = QInProjF 'T5 gradient device dataType queryEmbedDim embedDim 
QInProjF _ gradient device dataType queryEmbedDim embedDim = NamedModel (GLinearF 'WithBias gradient device dataType queryEmbedDim embedDim) 

type family KInProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

Specifies the linear transformation of the key.

Equations

KInProjF 'T5 gradient device dataType keyEmbedDim embedDim = NamedModel (GLinearF 'WithoutBias gradient device dataType keyEmbedDim embedDim) 
KInProjF 'ByT5 gradient device dataType keyEmbedDim embedDim = KInProjF 'T5 gradient device dataType keyEmbedDim embedDim 
KInProjF _ gradient device dataType keyEmbedDim embedDim = NamedModel (GLinearF 'WithBias gradient device dataType keyEmbedDim embedDim) 

type family VInProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (valueEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

Specifies the linear transformation of the value.

Equations

VInProjF 'T5 gradient device dataType valueEmbedDim embedDim = NamedModel (GLinearF 'WithoutBias gradient device dataType valueEmbedDim embedDim) 
VInProjF 'ByT5 gradient device dataType valueEmbedDim embedDim = VInProjF 'T5 gradient device dataType valueEmbedDim embedDim 
VInProjF _ gradient device dataType valueEmbedDim embedDim = NamedModel (GLinearF 'WithBias gradient device dataType valueEmbedDim embedDim) 

type family OutProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

Specifies the type of the out-projection layer.

Equations

OutProjF 'T5 gradient device dataType embedDim queryEmbedDim = NamedModel (GLinearF 'WithoutBias gradient device dataType embedDim queryEmbedDim) 
OutProjF 'ByT5 gradient device dataType embedDim queryEmbedDim = OutProjF 'T5 gradient device dataType embedDim queryEmbedDim 
OutProjF _ gradient device dataType embedDim queryEmbedDim = NamedModel (GLinearF 'WithBias gradient device dataType embedDim queryEmbedDim) 

type family DropoutF (style :: TransformerStyle) (hasDropout :: HasDropout) :: Type where ... Source #

Specifies the type of the dropout layer.

multiHeadAttentionSpec :: forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout. STransformerStyle style -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim headDim -> SDim headEmbedDim -> SDim embedDim -> SDim queryEmbedDim -> SDim keyEmbedDim -> SDim valueEmbedDim -> SHasDropout hasDropout -> Double -> ModelSpec (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout) Source #

Specifies the parameters of a multi-headed attention layer.

  • style: the style of the attention layer, e.g. ST5, SByT5, etc.
  • gradient: whether to compute the gradient of the attention layer.
  • device: the computational device on which to allocate the attention layer.
  • dataType: the data type of the attention layer.
  • headDim: the dimension of the attention heads.
  • headEmbedDim: the dimension of the attention head embeddings.
  • embedDim: the dimension of the embeddings that the query, key, and value are projected to.
  • queryEmbedDim: the dimension of the query embeddings.
  • keyEmbedDim: the dimension of the key embeddings.
  • valueEmbedDim: the dimension of the value embeddings.
  • dropoutP: the dropout rate.

type BatchDim queryShape keyShape valueShape = (queryShape ! 0) <+> ((keyShape ! 0) <+> (valueShape ! 0)) Source #

getBatchDim :: forall m queryShape keyShape valueShape batchDim. (MonadThrow m, batchDim ~ BatchDim queryShape keyShape valueShape) => SShape queryShape -> SShape keyShape -> SShape valueShape -> m (SDim batchDim) Source #

type QuerySeqDim queryShape = queryShape ! 1 Source #

getQuerySeqDim :: forall m queryShape querySeqDim. (MonadThrow m, querySeqDim ~ QuerySeqDim queryShape) => SShape queryShape -> m (SDim querySeqDim) Source #

type KeySeqDim keyShape valueShape = (keyShape ! 1) <+> (valueShape ! 1) Source #

getKeySeqDim :: forall m keyShape valueShape keySeqDim. (MonadThrow m, keySeqDim ~ KeySeqDim keyShape valueShape) => SShape keyShape -> SShape valueShape -> m (SDim keySeqDim) Source #