Safe Haskell | Safe-Inferred |
---|---|
Language | Haskell2010 |
Synopsis
- data MultiHeadAttentionHasScaling
- data GMultiHeadAttention (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (qInProj :: Type) (kInProj :: Type) (vInProj :: Type) (outProj :: Type) (dropout :: Type) where
- GMultiHeadAttention :: forall headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout. {..} -> GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout
- type family GMultiHeadAttentionF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (valueEmbedDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) :: Type where ...
- type family QInProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ...
- type family KInProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ...
- type family VInProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (valueEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ...
- type family OutProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ...
- type family DropoutF (style :: TransformerStyle) (hasDropout :: HasDropout) :: Type where ...
- multiHeadAttentionSpec :: forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout. STransformerStyle style -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim headDim -> SDim headEmbedDim -> SDim embedDim -> SDim queryEmbedDim -> SDim keyEmbedDim -> SDim valueEmbedDim -> SHasDropout hasDropout -> Double -> ModelSpec (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout)
- type BatchDim queryShape keyShape valueShape = (queryShape ! 0) <+> ((keyShape ! 0) <+> (valueShape ! 0))
- getBatchDim :: forall m queryShape keyShape valueShape batchDim. (MonadThrow m, batchDim ~ BatchDim queryShape keyShape valueShape) => SShape queryShape -> SShape keyShape -> SShape valueShape -> m (SDim batchDim)
- type QuerySeqDim queryShape = queryShape ! 1
- getQuerySeqDim :: forall m queryShape querySeqDim. (MonadThrow m, querySeqDim ~ QuerySeqDim queryShape) => SShape queryShape -> m (SDim querySeqDim)
- type KeySeqDim keyShape valueShape = (keyShape ! 1) <+> (valueShape ! 1)
- getKeySeqDim :: forall m keyShape valueShape keySeqDim. (MonadThrow m, keySeqDim ~ KeySeqDim keyShape valueShape) => SShape keyShape -> SShape valueShape -> m (SDim keySeqDim)
Documentation
data MultiHeadAttentionHasScaling Source #
Data type for representing whether or not (and, if so, where) scaling is applied in the multi-headed attention layer.
MultiHeadAttentionWithoutScaling | Scaling is not done. |
MultiHeadAttentionWithQueryScaling | Scaling is applied to the query after the in-projection. |
MultiHeadAttentionWithWeightScaling | Scaling is applied to the attention weights. |
Instances
data GMultiHeadAttention (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (qInProj :: Type) (kInProj :: Type) (vInProj :: Type) (outProj :: Type) (dropout :: Type) where Source #
Generic multi-headed attention layer.
headDim is the dimension of the attention heads.
headEmbedDim is the dimension of the attention head embedding.
embedDim is the dimension of the embedding.
qInProj is the type of the query projection.
kInProj is the type of the key projection.
vInProj is the type of the value projection.
outProj is the type of the output projection.
dropout is the type of the dropout layer.
GMultiHeadAttention | |
|
Instances
Generic (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention type Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) :: Type -> Type Source # from :: GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) x Source # to :: Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) x -> GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout Source # | |
(Show qInProj, Show kInProj, Show vInProj, Show outProj, Show dropout) => Show (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention showsPrec :: Int -> GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> ShowS Source # show :: GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> String Source # showList :: [GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout] -> ShowS Source # | |
(HasStateDict qInProj, HasStateDict vInProj, HasStateDict kInProj, HasStateDict outProj, HasStateDict dropout) => HasStateDict (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) -> StateDictKey -> m (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> m () Source # | |
(HasInitialize qInProj generatorDevice qInProj' generatorDevice0, HasInitialize kInProj generatorDevice0 kInProj' generatorDevice1, HasInitialize vInProj generatorDevice1 vInProj' generatorDevice2, HasInitialize outProj generatorDevice2 outProj' generatorDevice3, HasInitialize dropout generatorDevice3 dropout' generatorOutputDevice) => HasInitialize (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) generatorDevice (GMultiHeadAttention headDim headEmbedDim embedDim qInProj' kInProj' vInProj' outProj' dropout') generatorOutputDevice Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention initialize :: MonadThrow m => ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) -> Generator generatorDevice -> m (GMultiHeadAttention headDim headEmbedDim embedDim qInProj' kInProj' vInProj' outProj' dropout', Generator generatorOutputDevice) Source # | |
(HasForward qInProj (Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape) generatorDevice (Tensor qRequiresGradient qLayout qDevice qDataType qShape0) qGeneratorOutputDevice, reshapedQShape0 ~ ReshapeF qShape0 ('Shape '[batchDim, querySeqDim, headDim, headEmbedDim]), Catch reshapedQShape0, qShape ~ TransposeF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) ('SelectDim ('ByIndex 2 :: By Symbol Natural)) reshapedQShape0, Catch qShape, HasForward kInProj (Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape) qGeneratorOutputDevice (Tensor qRequiresGradient kLayout kDevice kDataType kShape0) kGeneratorOutputDevice, reshapedKShape0 ~ ReshapeF kShape0 ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]), Catch reshapedKShape0, transposedReshapedKShape0 ~ TransposeF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) ('SelectDim ('ByIndex 2 :: By Symbol Natural)) reshapedKShape0, Catch transposedReshapedKShape0, doubleTransposedReshapedKShape0 ~ TransposeF ('SelectDim ('ByIndex 2 :: By Symbol Natural)) ('SelectDim ('ByIndex 3 :: By Symbol Natural)) transposedReshapedKShape0, Catch doubleTransposedReshapedKShape0, multipliedQDoubleTransposedReshapedKShape0 ~ MatmulF qShape doubleTransposedReshapedKShape0, Catch multipliedQDoubleTransposedReshapedKShape0, weightsShape0 ~ SoftmaxF ('SelectDim ('ByIndex 3 :: By Symbol Natural)) (BroadcastShapesF multipliedQDoubleTransposedReshapedKShape0 attentionBiasShape), Catch (BroadcastShapesF multipliedQDoubleTransposedReshapedKShape0 attentionBiasShape), Catch weightsShape0, HasForward dropout (Tensor (qRequiresGradient <|> attentionBiasRequiresGradient) (qLayout <+> (kLayout <+> attentionBiasLayout)) (qDevice <+> (kDevice <+> attentionBiasDevice)) (qDataType <+> (kDataType <+> attentionBiasDataType)) weightsShape0) kGeneratorOutputDevice (Tensor weightsRequiresGradient weightsLayout weightsDevice weightsDataType weightsShape) weightsGeneratorOutputDevice, HasForward vInProj (Tensor 
valueRequiresGradient valueLayout valueDevice valueDataType valueShape) weightsGeneratorOutputDevice (Tensor weightsRequiresGradient vLayout vDevice vDataType vShape0) vGeneratorOutputDevice, reshapedVShape0 ~ ReshapeF vShape0 ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]), Catch reshapedVShape0, transposedReshapedVShape ~ TransposeF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) ('SelectDim ('ByIndex 2 :: By Symbol Natural)) reshapedVShape0, Catch transposedReshapedVShape, multipliedWeightsTransposedReshapedVShape ~ MatmulF weightsShape transposedReshapedVShape, Catch multipliedWeightsTransposedReshapedVShape, outputQueryShape0 ~ TransposeF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) ('SelectDim ('ByIndex 2 :: By Symbol Natural)) multipliedWeightsTransposedReshapedVShape, Catch outputQueryShape0, HasForward outProj (Tensor weightsRequiresGradient (weightsLayout <+> vLayout) (weightsDevice <+> vDevice) (weightsDataType <+> vDataType) reshapedOutputQueryShape0) vGeneratorOutputDevice output generatorOutputDevice, reshapedOutputQueryShape0 ~ ReshapeF outputQueryShape0 ('Shape '[batchDim, querySeqDim, embedDim]), Catch reshapedOutputQueryShape0, SGetShape queryShape, SGetShape keyShape, SGetShape valueShape, batchDim ~ BatchDim queryShape keyShape valueShape, querySeqDim ~ QuerySeqDim queryShape, keySeqDim ~ KeySeqDim keyShape valueShape) => HasForward (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) (Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape, Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape, Tensor valueRequiresGradient valueLayout valueDevice valueDataType valueShape, Tensor attentionBiasRequiresGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape) generatorDevice output generatorOutputDevice Source # |
┌───────────────┐ ┌───────┐ ┌─────┐ ┌───────┐ │ attentionBias │ │ query │ │ key │ │ value │ └───────┬───────┘ └───┬───┘ └──┬──┘ └───┬───┘ │ │ │ │ │ ▼ ▼ ▼ │ mhaQInProj mhaKInProj mhaVInProj │ ▼ │ │ │ (scaling) │ │ │ ▼ ▼ ▼ │ reshape reshape reshape │ ▼ ▼ ▼ │ transpose transpose transpose │ │ ▼ │ │ │ transpose │ │ │ │ │ │ └───►matmul◄───┘ │ │ ▼ │ │ (scaling) │ │ │ │ └──────────►add◄────────────┘ │ ▼ │ softmax │ ▼ │ mhaDropout │ │ │ └──────────────►matmul◄───────────────┘ ▼ transpose ▼ reshape ▼ mhaOutProj │ ▼ ┌───────┐ │ query │ └───────┘ |
Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention forward :: MonadThrow m => GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout -> (Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape, Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape, Tensor valueRequiresGradient valueLayout valueDevice valueDataType valueShape, Tensor attentionBiasRequiresGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape) -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source # | |
type Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention type Rep (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) = D1 ('MetaData "GMultiHeadAttention" "Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "GMultiHeadAttention" 'PrefixI 'True) (((S1 ('MetaSel ('Just "mhaHeadDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim headDim)) :*: S1 ('MetaSel ('Just "mhaHeadEmbedDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim headEmbedDim))) :*: (S1 ('MetaSel ('Just "mhaEmbedDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim embedDim)) :*: S1 ('MetaSel ('Just "mhaQInProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 qInProj))) :*: ((S1 ('MetaSel ('Just "mhaKInProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 kInProj) :*: S1 ('MetaSel ('Just "mhaVInProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 vInProj)) :*: (S1 ('MetaSel ('Just "mhaOutProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 outProj) :*: (S1 ('MetaSel ('Just "mhaDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 dropout) :*: S1 ('MetaSel ('Just "mhaScaling") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 MultiHeadAttentionHasScaling)))))) | |
type ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention type ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) = GMultiHeadAttention headDim headEmbedDim embedDim (ModelSpec qInProj) (ModelSpec kInProj) (ModelSpec vInProj) (ModelSpec outProj) (ModelSpec dropout) |
type family GMultiHeadAttentionF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (valueEmbedDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) :: Type where ... Source #
GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout = GMultiHeadAttention headDim headEmbedDim embedDim (QInProjF style gradient device dataType queryEmbedDim embedDim) (KInProjF style gradient device dataType keyEmbedDim embedDim) (VInProjF style gradient device dataType valueEmbedDim embedDim) (OutProjF style gradient device dataType embedDim queryEmbedDim) (DropoutF style hasDropout) |
type family QInProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #
Specifies the linear transformation of the query.
QInProjF 'T5 gradient device dataType queryEmbedDim embedDim = NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim embedDim) | |
QInProjF 'ByT5 gradient device dataType queryEmbedDim embedDim = QInProjF 'T5 gradient device dataType queryEmbedDim embedDim | |
QInProjF _ gradient device dataType queryEmbedDim embedDim = NamedModel (GLinearF 'WithBias gradient device dataType queryEmbedDim embedDim) |
type family KInProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #
Specifies the linear transformation of the key.
KInProjF 'T5 gradient device dataType keyEmbedDim embedDim = NamedModel (GLinearF 'WithoutBias gradient device dataType keyEmbedDim embedDim) | |
KInProjF 'ByT5 gradient device dataType keyEmbedDim embedDim = KInProjF 'T5 gradient device dataType keyEmbedDim embedDim | |
KInProjF _ gradient device dataType keyEmbedDim embedDim = NamedModel (GLinearF 'WithBias gradient device dataType keyEmbedDim embedDim) |
type family VInProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (valueEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #
Specifies the linear transformation of the value.
VInProjF 'T5 gradient device dataType valueEmbedDim embedDim = NamedModel (GLinearF 'WithoutBias gradient device dataType valueEmbedDim embedDim) | |
VInProjF 'ByT5 gradient device dataType valueEmbedDim embedDim = VInProjF 'T5 gradient device dataType valueEmbedDim embedDim | |
VInProjF _ gradient device dataType valueEmbedDim embedDim = NamedModel (GLinearF 'WithBias gradient device dataType valueEmbedDim embedDim) |
type family OutProjF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #
Specifies the type of the out-projection layer.
OutProjF 'T5 gradient device dataType embedDim queryEmbedDim = NamedModel (GLinearF 'WithoutBias gradient device dataType embedDim queryEmbedDim) | |
OutProjF 'ByT5 gradient device dataType embedDim queryEmbedDim = OutProjF 'T5 gradient device dataType embedDim queryEmbedDim | |
OutProjF _ gradient device dataType embedDim queryEmbedDim = NamedModel (GLinearF 'WithBias gradient device dataType embedDim queryEmbedDim) |
type family DropoutF (style :: TransformerStyle) (hasDropout :: HasDropout) :: Type where ... Source #
Specifies the type of the dropout layer.
DropoutF _ 'WithDropout = Dropout | |
DropoutF _ 'WithoutDropout = () |
multiHeadAttentionSpec :: forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout. STransformerStyle style -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim headDim -> SDim headEmbedDim -> SDim embedDim -> SDim queryEmbedDim -> SDim keyEmbedDim -> SDim valueEmbedDim -> SHasDropout hasDropout -> Double -> ModelSpec (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout) Source #
Specifies the parameters of a multi-headed attention layer.
style: the style of the attention layer, e.g. ST5, SByT5, etc.
gradient: whether to compute the gradient of the attention layer.
device: the computational device on which to allocate the attention layer.
dataType: the data type of the attention layer.
headDim: the dimension of the attention heads.
headEmbedDim: the dimension of the attention head embeddings.
embedDim: the dimension of the embeddings produced by the query/key/value in-projections and consumed by the out-projection.
queryEmbedDim: the dimension of the query embeddings.
keyEmbedDim: the dimension of the key embeddings.
valueEmbedDim: the dimension of the value embeddings.
hasDropout: whether the attention layer contains a dropout layer (WithDropout) or not (WithoutDropout).
dropoutP: the dropout rate.
type BatchDim queryShape keyShape valueShape = (queryShape ! 0) <+> ((keyShape ! 0) <+> (valueShape ! 0)) Source #
getBatchDim :: forall m queryShape keyShape valueShape batchDim. (MonadThrow m, batchDim ~ BatchDim queryShape keyShape valueShape) => SShape queryShape -> SShape keyShape -> SShape valueShape -> m (SDim batchDim) Source #
type QuerySeqDim queryShape = queryShape ! 1 Source #
getQuerySeqDim :: forall m queryShape querySeqDim. (MonadThrow m, querySeqDim ~ QuerySeqDim queryShape) => SShape queryShape -> m (SDim querySeqDim) Source #
type KeySeqDim keyShape valueShape = (keyShape ! 1) <+> (valueShape ! 1) Source #
getKeySeqDim :: forall m keyShape valueShape keySeqDim. (MonadThrow m, keySeqDim ~ KeySeqDim keyShape valueShape) => SShape keyShape -> SShape valueShape -> m (SDim keySeqDim) Source #