hasktorch-gradually-typed-0.2.0.0: experimental project for hasktorch
Safe Haskell: Safe-Inferred
Language: Haskell2010

Torch.GraduallyTyped.NN.Transformer.GBlock

Synopsis

Documentation

data GTransformerBlock (selfAttention :: Type) (crossAttention :: Type) (feedForwardNetwork :: Type) where Source #

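Generic transformer block consisting of self-attention, an optional cross-attention, and a feed-forward network. It is used both as an encoder block (without cross-attention) and as a decoder block (with cross-attention); see the HasForward instances below.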

  • selfAttention is a self-attention layer.
  • crossAttention is a cross-attention layer, or () when the block is used in the encoder configuration.
  • feedForwardNetwork is a feed-forward layer.

TODO: Some transformers use LayerDrop (see https://arxiv.org/abs/1909.11556) during training. To support this, we will need a layer wrapper that acts as either the identity function or the wrapped layer, chosen by a uniformly random draw from a supplied generator.

Constructors

GTransformerBlock 

Fields

Instances

Instances details
Generic (GTransformerBlock selfAttention crossAttention feedForwardNetwork) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

Associated Types

type Rep (GTransformerBlock selfAttention crossAttention feedForwardNetwork) :: Type -> Type Source #

Methods

from :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> Rep (GTransformerBlock selfAttention crossAttention feedForwardNetwork) x Source #

to :: Rep (GTransformerBlock selfAttention crossAttention feedForwardNetwork) x -> GTransformerBlock selfAttention crossAttention feedForwardNetwork Source #

(Show selfAttention, Show crossAttention, Show feedForwardNetwork) => Show (GTransformerBlock selfAttention crossAttention feedForwardNetwork) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

Methods

showsPrec :: Int -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> ShowS Source #

show :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> String Source #

showList :: [GTransformerBlock selfAttention crossAttention feedForwardNetwork] -> ShowS Source #

(Eq selfAttention, Eq crossAttention, Eq feedForwardNetwork) => Eq (GTransformerBlock selfAttention crossAttention feedForwardNetwork) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

Methods

(==) :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> Bool Source #

(/=) :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> Bool Source #

(Ord selfAttention, Ord crossAttention, Ord feedForwardNetwork) => Ord (GTransformerBlock selfAttention crossAttention feedForwardNetwork) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

Methods

compare :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> Ordering Source #

(<) :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> Bool Source #

(<=) :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> Bool Source #

(>) :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> Bool Source #

(>=) :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> Bool Source #

max :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork Source #

min :: GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> GTransformerBlock selfAttention crossAttention feedForwardNetwork Source #

(HasStateDict selfAttention, HasStateDict crossAttention, HasStateDict feedForwardNetwork) => HasStateDict (GTransformerBlock selfAttention crossAttention feedForwardNetwork) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

Methods

fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (GTransformerBlock selfAttention crossAttention feedForwardNetwork) -> StateDictKey -> m (GTransformerBlock selfAttention crossAttention feedForwardNetwork) Source #

toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> GTransformerBlock selfAttention crossAttention feedForwardNetwork -> m () Source #

(HasInitialize selfAttention generatorDevice selfAttention' generatorDevice0, HasInitialize crossAttention generatorDevice0 crossAttention' generatorDevice1, HasInitialize feedForwardNetwork generatorDevice1 feedForwardNetwork' generatorOutputDevice) => HasInitialize (GTransformerBlock selfAttention crossAttention feedForwardNetwork) generatorDevice (GTransformerBlock selfAttention' crossAttention' feedForwardNetwork') generatorOutputDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

Methods

initialize :: MonadThrow m => ModelSpec (GTransformerBlock selfAttention crossAttention feedForwardNetwork) -> Generator generatorDevice -> m (GTransformerBlock selfAttention' crossAttention' feedForwardNetwork', Generator generatorOutputDevice) Source #

(HasForward selfAttention (query, attentionBias) generatorDevice tensor0 generatorDevice0, HasForward feedForwardNetwork tensor0 generatorDevice0 output generatorOutputDevice) => HasForward (GTransformerBlock selfAttention () feedForwardNetwork) (query, attentionBias) generatorDevice output generatorOutputDevice Source #

HasForward instance for GTransformerBlock in an encoder configuration, i.e. without cross-attention (the crossAttention field is ()).

     ┌───────┐  ┌───────────────┐
     │ query │  │ attentionBias │
     └───┬───┘  └───────┬───────┘
         │              │
         ▼              │
  tbSelfAttention◄──────┘
         ▼
tbFeedForwardNetwork
         │
         ▼
     ┌───────┐
     │ query │
     └───────┘
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

Methods

forward :: MonadThrow m => GTransformerBlock selfAttention () feedForwardNetwork -> (query, attentionBias) -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #

(HasForward selfAttention (query, attentionBias) generatorDevice tensor0 generatorDevice0, HasForward crossAttention (tensor0, key, crossAttentionBias) generatorDevice0 tensor1 generatorDevice1, HasForward feedForwardNetwork tensor1 generatorDevice1 output generatorOutputDevice) => HasForward (GTransformerBlock selfAttention crossAttention feedForwardNetwork) (query, key, attentionBias, crossAttentionBias) generatorDevice output generatorOutputDevice Source #

HasForward instance for GTransformerBlock in a decoder configuration.

┌───────────────┐  ┌───────┐  ┌─────┐  ┌────────────────────┐
│ attentionBias │  │ query │  │ key │  │ crossAttentionBias │
└───────┬───────┘  └───┬───┘  └──┬──┘  └─────────┬──────────┘
        │              │         │               │
        │              ▼         │               │
        └───────►tbSelfAttention │               │
                       │         │               │
                       ▼         ▼               │
                    tbCrossAttention◄────────────┘
                       │
                       ▼
             tbFeedForwardNetwork
                       │
                       ▼
                   ┌───────┐
                   │ query │
                   └───────┘
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

Methods

forward :: MonadThrow m => GTransformerBlock selfAttention crossAttention feedForwardNetwork -> (query, key, attentionBias, crossAttentionBias) -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #

type Rep (GTransformerBlock selfAttention crossAttention feedForwardNetwork) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

type Rep (GTransformerBlock selfAttention crossAttention feedForwardNetwork) = D1 ('MetaData "GTransformerBlock" "Torch.GraduallyTyped.NN.Transformer.GBlock" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "GTransformerBlock" 'PrefixI 'True) (S1 ('MetaSel ('Just "tbSelfAttention") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 selfAttention) :*: (S1 ('MetaSel ('Just "tbCrossAttention") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 crossAttention) :*: S1 ('MetaSel ('Just "tbFeedForwardNetwork") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 feedForwardNetwork))))
type ModelSpec (GTransformerBlock selfAttention crossAttention feedForwardNetwork) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GBlock

type ModelSpec (GTransformerBlock selfAttention crossAttention feedForwardNetwork) = GTransformerBlock (ModelSpec selfAttention) (ModelSpec crossAttention) (ModelSpec feedForwardNetwork)

type family EncoderBlockF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) where ... Source #

Equations

EncoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout = GTransformerBlock (NamedModel (GSelfAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout)) () (NamedModel (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout)) 

encoderBlockSpec :: forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout. STransformerStyle style -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim headDim -> SDim headEmbedDim -> SDim embedDim -> SDim queryEmbedDim -> SDim ffnDim -> SHasDropout hasDropout -> Double -> Double -> ModelSpec (EncoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout) Source #

type family DecoderBlockF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (headDim :: Dim (Name Symbol) (Size Nat)) (headEmbedDim :: Dim (Name Symbol) (Size Nat)) (embedDim :: Dim (Name Symbol) (Size Nat)) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (keyEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) where ... Source #

Equations

DecoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout = GTransformerBlock (NamedModel (GSelfAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout)) (NamedModel (GCrossAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout)) (NamedModel (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout)) 

decoderBlockSpec :: forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout. STransformerStyle style -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim headDim -> SDim headEmbedDim -> SDim embedDim -> SDim queryEmbedDim -> SDim keyEmbedDim -> SDim ffnDim -> SHasDropout hasDropout -> Double -> Double -> ModelSpec (DecoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout) Source #