hasktorch-gradually-typed-0.2.0.0: experimental project for hasktorch
Safe Haskell: Safe-Inferred
Language: Haskell2010

Torch.GraduallyTyped.NN.Transformer.GCrossAttention

Documentation

data GCrossAttention (initialLayerNorm :: Type) (mha :: Type) (dropout :: Type) (finalLayerNorm :: Type) where Source #

Generic cross-attention layer data type.

  • initialLayerNorm: the initial layer normalization
  • mha: the multi-headed attention layer
  • dropout: the dropout layer
  • finalLayerNorm: the final layer normalization

Constructors

GCrossAttention

Fields

  :: forall initialLayerNorm mha dropout finalLayerNorm.
     { caInitialLayerNorm :: initialLayerNorm
       -- ^ initial layer normalization of the cross-attention layer
     , caMultiHeadAttention :: mha
       -- ^ multi-headed attention layer specialized for cross-attention
     , caDropout :: dropout
       -- ^ dropout layer of the cross-attention layer
     , caFinalLayerNorm :: finalLayerNorm
       -- ^ final layer normalization of the cross-attention layer
     }
  -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm

Instances

Generic (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GCrossAttention

Associated Types

type Rep (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) :: Type -> Type Source #

Methods

from :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> Rep (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) x Source #

to :: Rep (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) x -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm Source #

(Show initialLayerNorm, Show mha, Show dropout, Show finalLayerNorm) => Show (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GCrossAttention

Methods

showsPrec :: Int -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> ShowS Source #

show :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> String Source #

showList :: [GCrossAttention initialLayerNorm mha dropout finalLayerNorm] -> ShowS Source #

(Eq initialLayerNorm, Eq mha, Eq dropout, Eq finalLayerNorm) => Eq (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GCrossAttention

Methods

(==) :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> Bool Source #

(/=) :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> Bool Source #

(Ord initialLayerNorm, Ord mha, Ord dropout, Ord finalLayerNorm) => Ord (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GCrossAttention

Methods

compare :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> Ordering Source #

(<) :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> Bool Source #

(<=) :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> Bool Source #

(>) :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> Bool Source #

(>=) :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> Bool Source #

max :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm Source #

min :: GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm -> GCrossAttention initialLayerNorm mha dropout finalLayerNorm Source #

(HasStateDict initialLayerNorm, HasStateDict multiHeadAttention, HasStateDict dropout, HasStateDict finalLayerNorm) => HasStateDict (GCrossAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GCrossAttention

Methods

fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (GCrossAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm) -> StateDictKey -> m (GCrossAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm) Source #

toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> GCrossAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm -> m () Source #
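
As an illustration, here is a minimal sketch of loading and saving a layer through a state dict. It assumes a spec value caSpec built elsewhere (e.g. with crossAttentionSpec below) and the stateDictFromFile/stateDictToFile helpers; the file name is hypothetical.

  import Control.Monad.State (evalStateT, execStateT)
  import Torch.GraduallyTyped

  -- Load: 'fromStateDict' consumes entries from the StateDict state,
  -- looking them up under the given key prefix ('mempty' means the root).
  loadLayer caSpec = do
    stateDict <- stateDictFromFile "cross_attention.pt" -- hypothetical file
    flip evalStateT stateDict $ fromStateDict caSpec mempty

  -- Save: 'toStateDict' inserts the layer's tensors into the StateDict
  -- state under the given key prefix, starting from an empty dict.
  saveLayer layer = do
    stateDict <- flip execStateT mempty $ toStateDict mempty layer
    stateDictToFile stateDict "cross_attention.pt"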

( HasInitialize initialLayerNorm generatorDevice initialLayerNorm' generatorDevice0
, HasInitialize multiHeadAttention generatorDevice0 multiHeadAttention' generatorDevice1
, HasInitialize dropout generatorDevice1 dropout' generatorDevice2
, HasInitialize finalLayerNorm generatorDevice2 finalLayerNorm' generatorOutputDevice
) => HasInitialize (GCrossAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm) generatorDevice (GCrossAttention initialLayerNorm' multiHeadAttention' dropout' finalLayerNorm') generatorOutputDevice Source #
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GCrossAttention

Methods

initialize :: MonadThrow m => ModelSpec (GCrossAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm) -> Generator generatorDevice -> m (GCrossAttention initialLayerNorm' multiHeadAttention' dropout' finalLayerNorm', Generator generatorOutputDevice) Source #
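
A minimal sketch of random initialization, assuming a spec value caSpec and the sMkGenerator helper from Torch.GraduallyTyped.Random; device and seed are illustrative.

  import Torch.GraduallyTyped

  initLayer caSpec = do
    g <- sMkGenerator (SDevice SCPU) 0 -- CPU generator seeded with 0
    -- 'initialize' threads the generator through all four sub-models and
    -- returns the initialized layer together with the advanced generator.
    initialize caSpec g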

( HasForward initialLayerNorm (Tensor queryGradient queryLayout queryDevice queryDataType queryShape) generatorDevice tensor0 generatorDevice0
, HasForward multiHeadAttention (tensor0, Tensor keyGradient keyLayout keyDevice keyDataType keyShape, Tensor keyGradient keyLayout keyDevice keyDataType keyShape, Tensor attentionBiasGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape) generatorDevice0 tensor1 generatorDevice1
, HasForward dropout tensor1 generatorDevice1 (Tensor gradient2 layout2 device2 dataType2 shape2) generatorDevice2
, HasForward finalLayerNorm (Tensor (queryGradient <|> gradient2) (queryLayout <+> layout2) (queryDevice <+> device2) (queryDataType <+> dataType2) (BroadcastShapesF queryShape shape2)) generatorDevice2 output generatorOutputDevice
, Catch (BroadcastShapesF queryShape shape2)
) => HasForward (GCrossAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm) (Tensor queryGradient queryLayout queryDevice queryDataType queryShape, Tensor keyGradient keyLayout keyDevice keyDataType keyShape, Tensor attentionBiasGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape) generatorDevice output generatorOutputDevice Source #

HasForward instance for GCrossAttention. The query is first (optionally) layer-normalized, then attends to the key via multi-headed attention; the result passes through dropout, is added back to the original query as a residual connection, and is finally (optionally) layer-normalized again:

       ┌───────┐    ┌─────┐    ┌───────────────┐
       │ query │    │ key │    │ attentionBias │
       └───┬───┘    └──┬──┘    └───────┬───────┘
           │           │               │
┌──────────┤           │               │
│          │           │               │
│          ▼           │               │
│ (caInitialLayerNorm) │               │
│          │           │               │
│          │       ┌───┴───┐           │
│          │       │       │           │
│          ▼       ▼       ▼           │
│        caMultiHeadAttention◄─────────┘
│                  │
│                  ▼
│              caDropout
│                  │
└──────►add◄───────┘
         │
         ▼
 (caFinalLayerNorm)
         │
         ▼
     ┌───────┐
     │ query │
     └───────┘
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GCrossAttention

Methods

forward :: MonadThrow m => GCrossAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm -> (Tensor queryGradient queryLayout queryDevice queryDataType queryShape, Tensor keyGradient keyLayout keyDevice keyDataType keyShape, Tensor attentionBiasGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape) -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #
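
Since the instance head fixes the input to a (query, key, attentionBias) triple, a call is a one-liner; the following sketch assumes the layer, tensors, and generator are created elsewhere.

  -- query: decoder hidden state; key: encoder output (serving as both
  -- keys and values); attentionBias: e.g. a padding mask for the key.
  runCrossAttention layer query key attentionBias g =
    forward layer (query, key, attentionBias) g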

type Rep (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GCrossAttention

type Rep (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) = D1 ('MetaData "GCrossAttention" "Torch.GraduallyTyped.NN.Transformer.GCrossAttention" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "GCrossAttention" 'PrefixI 'True) ((S1 ('MetaSel ('Just "caInitialLayerNorm") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 initialLayerNorm) :*: S1 ('MetaSel ('Just "caMultiHeadAttention") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 mha)) :*: (S1 ('MetaSel ('Just "caDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 dropout) :*: S1 ('MetaSel ('Just "caFinalLayerNorm") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 finalLayerNorm))))
type ModelSpec (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GCrossAttention

type ModelSpec (GCrossAttention initialLayerNorm mha dropout finalLayerNorm) = GCrossAttention (ModelSpec initialLayerNorm) (ModelSpec mha) (ModelSpec dropout) (ModelSpec finalLayerNorm)
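
In other words, a specification for GCrossAttention is itself a GCrossAttention value whose fields hold the specifications of the four sub-models; crossAttentionSpec below constructs such a value.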

type family GCrossAttentionF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) :: Type where ... Source #

Equations

GCrossAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout = GCrossAttention (CAInitialLayerNormF style gradient device dataType queryEmbedDim) (CAMultiheadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout) (CADropoutF style hasDropout) (CAFinalLayerNormF style gradient device dataType queryEmbedDim) 
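
For concreteness, here is a hypothetical fully applied use of the family with illustrative T5-small-like dimensions; per the equations below, the 'T5 style resolves the initial layer normalization to a bias-free LayerNorm and the final layer normalization to ().

  {-# LANGUAGE DataKinds #-}

  import Torch.GraduallyTyped

  type T5CrossAttention =
    GCrossAttentionF
      'T5
      ('Gradient 'WithGradient)
      ('Device 'CPU)
      ('DataType 'Float)
      ('Dim ('Name "*") ('Size 8))   -- headDim
      ('Dim ('Name "*") ('Size 64))  -- headEmbedDim
      ('Dim ('Name "*") ('Size 512)) -- embedDim
      ('Dim ('Name "*") ('Size 512)) -- queryEmbedDim
      ('Dim ('Name "*") ('Size 512)) -- keyEmbedDim
      'WithDropout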

type family CAInitialLayerNormF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

Specifies the initial layer normalization of the cross-attention layer.

Equations

CAInitialLayerNormF 'T5 gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[queryEmbedDim])) 
CAInitialLayerNormF 'ByT5 gradient device dataType queryEmbedDim = CAInitialLayerNormF 'T5 gradient device dataType queryEmbedDim 
CAInitialLayerNormF 'BART _ _ _ _ = () 
CAInitialLayerNormF 'MBART gradient device dataType queryEmbedDim = CAInitialLayerNormF 'BART gradient device dataType queryEmbedDim 
CAInitialLayerNormF 'Pegasus gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim])) 

type family CAMultiheadAttentionF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) :: Type where ... Source #

Specifies the multi-headed attention layer specialized for cross-attention. Note that keyEmbedDim is passed for both the key and the value embedding dimension, since cross-attention reads its keys and values from the same encoder output.

Equations

CAMultiheadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout = NamedModel (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim keyEmbedDim hasDropout) 

type family CADropoutF (style :: TransformerStyle) (hasDropout :: HasDropout) :: Type where ... Source #

Specifies the dropout layer of the cross-attention layer.

type family CAFinalLayerNormF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

Specifies the final layer normalization of the cross-attention layer.

Equations

CAFinalLayerNormF 'T5 _ _ _ _ = () 
CAFinalLayerNormF 'ByT5 gradient device dataType queryEmbedDim = CAFinalLayerNormF 'T5 gradient device dataType queryEmbedDim 
CAFinalLayerNormF 'BART gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim])) 
CAFinalLayerNormF 'MBART gradient device dataType queryEmbedDim = CAFinalLayerNormF 'BART gradient device dataType queryEmbedDim 
CAFinalLayerNormF 'Pegasus gradient device dataType queryEmbedDim = () 

crossAttentionSpec :: forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout. STransformerStyle style -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim headDim -> SDim headEmbedDim -> SDim embedDim -> SDim queryEmbedDim -> SDim keyEmbedDim -> SHasDropout hasDropout -> Double -> Double -> ModelSpec (GCrossAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout) Source #

Specifies the parameters of a cross-attention layer; a usage sketch follows the parameter list.

  • style: the style of the transformer layer, e.g. ST5, SByT5, etc.
  • gradient: whether to compute the gradient of the layer's parameters.
  • device: the computational device on which the layer is allocated.
  • dataType: the data type of the layer's parameters.
  • headDim: the dimension of all transformer heads in the layer.
  • headEmbedDim: the dimension of the transformer head embeddings.
  • embedDim: the dimension of the transformer embeddings.
  • queryEmbedDim: the dimension of the transformer query embeddings.
  • keyEmbedDim: the dimension of the transformer key embeddings.
  • hasDropout: whether the layer includes dropout.
  • dropoutP: the dropout rate.
  • eps: the epsilon value for numerical stability of the layer normalization.
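
The following sketch builds a spec and initializes it; all dimensions and hyperparameters are illustrative, and the singleton constructors (ST5, SGradient, SName, SSize, :&:, SWithDropout, etc.) are the usual hasktorch-gradually-typed singletons.

  {-# LANGUAGE DataKinds #-}
  {-# LANGUAGE TypeApplications #-}

  import Torch.GraduallyTyped

  main :: IO ()
  main = do
    let caSpec =
          crossAttentionSpec
            ST5                         -- style
            (SGradient SWithGradient)   -- gradient
            (SDevice SCPU)              -- device
            (SDataType SFloat)          -- dataType
            (SName @"*" :&: SSize @8)   -- headDim
            (SName @"*" :&: SSize @64)  -- headEmbedDim
            (SName @"*" :&: SSize @512) -- embedDim
            (SName @"*" :&: SSize @512) -- queryEmbedDim
            (SName @"*" :&: SSize @512) -- keyEmbedDim
            SWithDropout                -- hasDropout
            0.1                         -- dropoutP
            1e-6                        -- eps
    g <- sMkGenerator (SDevice SCPU) 0
    (caLayer, _) <- initialize caSpec g
    -- caLayer can now be used with 'forward' (see the HasForward instance).
    pure ()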