hasktorch-gradually-typed-0.2.0.0: experimental project for hasktorch
Safe HaskellSafe-Inferred
LanguageHaskell2010

Torch.GraduallyTyped.NN.Transformer.GStack

Synopsis

Documentation

newtype GTransformerStack (stack :: Type) where Source #

Generic transformer stack.

  • stack is a stack of transformer blocks.

Constructors

GTransformerStack :: forall stack. stack -> GTransformerStack stack 

Instances

Instances details
Generic (GTransformerStack stack) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Associated Types

type Rep (GTransformerStack stack) :: Type -> Type Source #

Show stack => Show (GTransformerStack stack) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Eq stack => Eq (GTransformerStack stack) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Ord stack => Ord (GTransformerStack stack) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

HasStateDict block => HasStateDict (GTransformerStack (Vector numLayers block)) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

(HasInitialize block generatorDevice block' generatorDevice, numLayers' ~ (numLayers + 1)) => HasInitialize (GTransformerStack (Vector numLayers' block)) generatorDevice (GTransformerStack (Vector numLayers' block')) generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Methods

initialize :: MonadThrow m => ModelSpec (GTransformerStack (Vector numLayers' block)) -> Generator generatorDevice -> m (GTransformerStack (Vector numLayers' block'), Generator generatorDevice) Source #

(HasForward block (query, attentionBias) generatorDevice output generatorOutputDevice, HasForward block (output, attentionBias) generatorOutputDevice output generatorOutputDevice) => HasForward (GTransformerStack (Vector n block)) (query, attentionBias) generatorDevice output generatorOutputDevice Source #

HasForward instance for GTransformerStack in an encoder configuration.

┌───────┐  ┌───────────────┐
│ query │  │ attentionBias │
└───┬───┘  └───────┬───────┘
    │              │
    ▼              │
  block◄───────────┤
    ▼              │
  block◄───────────┤
    ▼              │
   ...            ...
    ▼              │
  block◄───────────┘
    │
    ▼
┌───────┐
│ query │
└───────┘
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Methods

forward :: MonadThrow m => GTransformerStack (Vector n block) -> (query, attentionBias) -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #

HasForward (GTransformerStack (Vector 0 block)) (query, attentionBias) generatorDevice query generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Methods

forward :: MonadThrow m => GTransformerStack (Vector 0 block) -> (query, attentionBias) -> Generator generatorDevice -> m (query, Generator generatorDevice) Source #

HasForward block (query, attentionBias) generatorDevice output generatorOutputDevice => HasForward (GTransformerStack (Vector 1 block)) (query, attentionBias) generatorDevice output generatorOutputDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Methods

forward :: MonadThrow m => GTransformerStack (Vector 1 block) -> (query, attentionBias) -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #

(HasForward block (query, key, attentionBias, crossAttentionBias) generatorDevice output generatorOutputDevice, HasForward block (output, key, attentionBias, crossAttentionBias) generatorOutputDevice output generatorOutputDevice) => HasForward (GTransformerStack (Vector n block)) (query, key, attentionBias, crossAttentionBias) generatorDevice output generatorOutputDevice Source #

HasForward instance for GTransformerStack in a decoder configuration.

┌───────┐  ┌─────┐  ┌───────────────┐  ┌────────────────────┐
│ query │  │ key │  │ attentionBias │  │ crossAttentionBias │
└───┬───┘  └──┬──┘  └───────┬───────┘  └─────────┬──────────┘
    │         │             │                    │
    ▼         │             │                    │
  block◄──────┤◄────────────┤◄───────────────────┤
    ▼         │             │                    │
  block◄──────┤◄────────────┤◄───────────────────┤
    ▼         │             │                    │
   ...       ...           ...                  ...
    ▼         │             │                    │
  block◄──────┘◄────────────┘◄───────────────────┘
    │
    ▼
┌───────┐
│ query │
└───────┘
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Methods

forward :: MonadThrow m => GTransformerStack (Vector n block) -> (query, key, attentionBias, crossAttentionBias) -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #

HasForward (GTransformerStack (Vector 0 block)) (query, key, attentionBias, crossAttentionBias) generator query generator Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Methods

forward :: MonadThrow m => GTransformerStack (Vector 0 block) -> (query, key, attentionBias, crossAttentionBias) -> Generator generator -> m (query, Generator generator) Source #

HasForward block (query, key, attentionBias, crossAttentionBias) generatorDevice output generatorOutputDevice => HasForward (GTransformerStack (Vector 1 block)) (query, key, attentionBias, crossAttentionBias) generatorDevice output generatorOutputDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

Methods

forward :: MonadThrow m => GTransformerStack (Vector 1 block) -> (query, key, attentionBias, crossAttentionBias) -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #

type Rep (GTransformerStack stack) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

type Rep (GTransformerStack stack) = D1 ('MetaData "GTransformerStack" "Torch.GraduallyTyped.NN.Transformer.GStack" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'True) (C1 ('MetaCons "GTransformerStack" 'PrefixI 'False) (S1 ('MetaSel ('Nothing :: Maybe Symbol) 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 stack)))
type ModelSpec (GTransformerStack stack) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GStack

type family EncoderStackF (style :: TransformerStyle) (numLayers :: Nat) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) where ... Source #

Equations

EncoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout = GTransformerStack (Vector numLayers (EncoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout)) 

encoderStackSpec :: forall style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout. STransformerStyle style -> SNat numLayers -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim headDim -> SDim headEmbedDim -> SDim embedDim -> SDim queryEmbedDim -> SDim ffnDim -> SHasDropout hasDropout -> Double -> Double -> ModelSpec (EncoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout) Source #

Specifies the parameters of a transformer stack in an encoder configuration.

  • style: the style of the transformer stack, e.g. ST5, SByT5, etc.
  • numLayers: the number of transformer blocks in the stack.
  • gradient: whether to compute the gradient of the stack's parameters.
  • device: the computational device on which the stack is allocated.
  • dataType: the data type of the stack's parameters.
  • headDim: the dimension of all transformer heads in the stack.
  • headEmbedDim: the dimension of the transformer head embeddings.
  • embedDim: the dimension of the transformer embeddings.
  • queryEmbedDim: the dimension of the transformer query embeddings.
  • ffnDim: the dimension of the feed-forward network.
  • hasDropout: whether the stack includes dropout layers.
  • dropoutP: the dropout rate.
  • eps: the epsilon value for numerical stability of the layer normalization.

type family DecoderStackF (style :: TransformerStyle) (numLayers :: Nat) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) where ... Source #

Equations

DecoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout = GTransformerStack (Vector numLayers (DecoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout)) 

decoderStackSpec :: forall style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout. STransformerStyle style -> SNat numLayers -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim headDim -> SDim headEmbedDim -> SDim embedDim -> SDim queryEmbedDim -> SDim keyEmbedDim -> SDim ffnDim -> SHasDropout hasDropout -> Double -> Double -> ModelSpec (DecoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout) Source #

Specifies the parameters of a transformer stack in a decoder configuration.

  • style: the style of the transformer stack, e.g. ST5, SByT5, etc.
  • numLayers: the number of transformer blocks in the stack.
  • gradient: whether to compute the gradient of the stack's parameters.
  • device: the computational device on which the stack is allocated.
  • dataType: the data type of the stack's parameters.
  • headDim: the dimension of all transformer heads in the stack.
  • headEmbedDim: the dimension of the transformer head embeddings.
  • embedDim: the dimension of the transformer embeddings.
  • queryEmbedDim: the dimension of the transformer query embeddings.
  • keyEmbedDim: the dimension of the transformer key embeddings.
  • ffnDim: the dimension of the feed-forward network.
  • hasDropout: whether the stack includes dropout layers.
  • dropoutP: the dropout rate.
  • eps: the epsilon value for numerical stability of the layer normalization.