hasktorch-0.2.0.0: Functional differentiable programming in Haskell
Safe Haskell: Safe-Inferred
Language: Haskell2010

Torch.Typed.NN.Transformer

Documentation

residual :: forall {device :: (DeviceType, Nat)} {dtype :: DType} {dtype' :: DType} {m} {shape :: [Nat]} {shape' :: [Nat]} {b}. (BasicArithmeticDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device dtype', BasicArithmeticDTypeIsValid device (DTypePromotionImpl dtype dtype' (CmpDType dtype dtype')), Monad m) => (Tensor device dtype shape -> m (Tensor device dtype' shape')) -> (Tensor device (DTypePromotionImpl dtype dtype' (CmpDType dtype dtype')) (CheckBroadcast shape shape' (ComputeBroadcast (ReverseImpl shape ('[] :: [Nat])) (ReverseImpl shape' ('[] :: [Nat])))) -> m b) -> Tensor device dtype shape -> m b Source #

data MultiheadAttentionSpec (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat) (numHeads :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

MultiheadAttentionSpec 

Fields

Instances

Instances details
Show (MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device -> ShowS Source #

show :: MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device -> String Source #

showList :: [MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device] -> ShowS Source #

Eq (MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

(==) :: MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device -> MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device -> Bool Source #

(/=) :: MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device -> MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device -> Bool Source #

(All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device) (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device -> IO (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) Source #

data MultiheadAttention (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat) (numHeads :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

MultiheadAttention 

Fields

Instances

Instances details
Generic (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Rep (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) :: Type -> Type Source #

Methods

from :: MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device -> Rep (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) x Source #

to :: Rep (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) x -> MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device Source #

Show (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device -> ShowS Source #

show :: MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device -> String Source #

showList :: [MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device] -> ShowS Source #

Parameterized (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Parameters (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) :: [Type] Source #

Methods

flattenParameters :: MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device -> HList (Parameters (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device)) Source #

replaceParameters :: MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device -> HList (Parameters (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device)) -> MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device Source #

(All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device) (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device -> IO (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) Source #

type Rep (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Rep (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) = D1 ('MetaData "MultiheadAttention" "Torch.Typed.NN.Transformer" "hasktorch-0.2.0.0-F6yFRaDiRF49lpq95SVuR8" 'False) (C1 ('MetaCons "MultiheadAttention" 'PrefixI 'True) ((S1 ('MetaSel ('Just "mhaQInProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (Linear embedDim embedDim dtype device)) :*: S1 ('MetaSel ('Just "mhaKInProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (Linear kEmbedDim embedDim dtype device))) :*: (S1 ('MetaSel ('Just "mhaVInProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (Linear vEmbedDim embedDim dtype device)) :*: (S1 ('MetaSel ('Just "mhaOutProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (Linear embedDim embedDim dtype device)) :*: S1 ('MetaSel ('Just "mhaDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 Dropout)))))
type Parameters (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Parameters (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device) = GParameters (Rep (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device))

multiheadAttention Source #

Arguments

:: forall embedDim kEmbedDim vEmbedDim numHeads seqLen seqLen' batchSize headDim dtype device. (1 <= numHeads, embedDim ~ (headDim * numHeads), All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim], KnownDType dtype, StandardFloatingPointDTypeValidation device dtype, MatMulDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device dtype, dtype ~ SumDType dtype, SumDTypeIsValid device dtype, KnownDevice device) 
=> MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device

multi-head attention model ADT

-> Bool

switch between training mode and evaluation mode (turns random dropout on and off)

-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])

optional attention mask

-> Maybe (Tensor device 'Bool '[batchSize, seqLen])

optional key padding mask

-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])

optional key relations

-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])

optional value relations

-> Tensor device dtype '[batchSize, seqLen', embedDim]

query representation

-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]

key representation

-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]

value representation

-> IO (Tensor device dtype '[batchSize, seqLen', embedDim], Tensor device dtype '[batchSize, seqLen', seqLen])

attention output, and the attention weights averaged over heads

data TransformerMLPSpec (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerMLPSpec 

Fields

Instances

Instances details
Show (TransformerMLPSpec embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerMLPSpec embedDim ffnDim dtype device -> ShowS Source #

show :: TransformerMLPSpec embedDim ffnDim dtype device -> String Source #

showList :: [TransformerMLPSpec embedDim ffnDim dtype device] -> ShowS Source #

Eq (TransformerMLPSpec embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

(==) :: TransformerMLPSpec embedDim ffnDim dtype device -> TransformerMLPSpec embedDim ffnDim dtype device -> Bool Source #

(/=) :: TransformerMLPSpec embedDim ffnDim dtype device -> TransformerMLPSpec embedDim ffnDim dtype device -> Bool Source #

(All KnownNat '[embedDim, ffnDim], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (TransformerMLPSpec embedDim ffnDim dtype device) (TransformerMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerMLPSpec embedDim ffnDim dtype device -> IO (TransformerMLP embedDim ffnDim dtype device) Source #

data TransformerMLP (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerMLP 

Fields

Instances

Instances details
Generic (TransformerMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Rep (TransformerMLP embedDim ffnDim dtype device) :: Type -> Type Source #

Methods

from :: TransformerMLP embedDim ffnDim dtype device -> Rep (TransformerMLP embedDim ffnDim dtype device) x Source #

to :: Rep (TransformerMLP embedDim ffnDim dtype device) x -> TransformerMLP embedDim ffnDim dtype device Source #

Show (TransformerMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerMLP embedDim ffnDim dtype device -> ShowS Source #

show :: TransformerMLP embedDim ffnDim dtype device -> String Source #

showList :: [TransformerMLP embedDim ffnDim dtype device] -> ShowS Source #

Parameterized (TransformerMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Parameters (TransformerMLP embedDim ffnDim dtype device) :: [Type] Source #

Methods

flattenParameters :: TransformerMLP embedDim ffnDim dtype device -> HList (Parameters (TransformerMLP embedDim ffnDim dtype device)) Source #

replaceParameters :: TransformerMLP embedDim ffnDim dtype device -> HList (Parameters (TransformerMLP embedDim ffnDim dtype device)) -> TransformerMLP embedDim ffnDim dtype device Source #

(All KnownNat '[embedDim, ffnDim], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (TransformerMLPSpec embedDim ffnDim dtype device) (TransformerMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerMLPSpec embedDim ffnDim dtype device -> IO (TransformerMLP embedDim ffnDim dtype device) Source #

type Rep (TransformerMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Rep (TransformerMLP embedDim ffnDim dtype device) = D1 ('MetaData "TransformerMLP" "Torch.Typed.NN.Transformer" "hasktorch-0.2.0.0-F6yFRaDiRF49lpq95SVuR8" 'False) (C1 ('MetaCons "TransformerMLP" 'PrefixI 'True) ((S1 ('MetaSel ('Just "linear0") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (Linear embedDim ffnDim dtype device)) :*: S1 ('MetaSel ('Just "linear1") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (Linear ffnDim embedDim dtype device))) :*: (S1 ('MetaSel ('Just "dropout0") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 Dropout) :*: (S1 ('MetaSel ('Just "dropout1") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 Dropout) :*: S1 ('MetaSel ('Just "ln") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (LayerNorm '[embedDim] dtype device))))))
type Parameters (TransformerMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Parameters (TransformerMLP embedDim ffnDim dtype device) = GParameters (Rep (TransformerMLP embedDim ffnDim dtype device))

transformerMLP Source #

Arguments

:: forall embedDim ffnDim seqLen batchSize dtype device. (BasicArithmeticDTypeIsValid device dtype, StandardFloatingPointDTypeValidation device dtype, KnownNat embedDim, IsSuffixOf '[embedDim] '[seqLen, batchSize, embedDim]) 
=> TransformerMLP embedDim ffnDim dtype device

MLP model ADT for transformer

-> Bool

switch between training mode and evaluation mode (turns random dropout on and off)

-> Tensor device dtype '[seqLen, batchSize, embedDim] 
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim]) 

data TransformerLayerSpec (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat) (numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLayerSpec 

Fields

Instances

Instances details
Show (TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> ShowS Source #

show :: TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> String Source #

showList :: [TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device] -> ShowS Source #

Eq (TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

(==) :: TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> Bool Source #

(/=) :: TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> Bool Source #

(All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, ffnDim], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> IO (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source #

data TransformerLayer (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat) (numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLayer 

Fields

Instances

Instances details
(1 <= numHeads, embedDim ~ (headDim * numHeads), All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim], IsSuffixOf '[embedDim] '[batchSize, seqLen, embedDim], KnownDType dtype, StandardFloatingPointDTypeValidation device dtype, MatMulDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device dtype, dtype ~ SumDType dtype, SumDTypeIsValid device dtype, KnownDevice device) => Apply' (FoldLayers batchSize seqLen dtype device) (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device, IO (Tensor device dtype '[batchSize, seqLen, embedDim])) (IO (Tensor device dtype '[batchSize, seqLen, embedDim])) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

apply' :: FoldLayers batchSize seqLen dtype device -> (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device, IO (Tensor device dtype '[batchSize, seqLen, embedDim])) -> IO (Tensor device dtype '[batchSize, seqLen, embedDim]) Source #

Generic (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Rep (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) :: Type -> Type Source #

Methods

from :: TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> Rep (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) x Source #

to :: Rep (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) x -> TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device Source #

Show (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> ShowS Source #

show :: TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> String Source #

showList :: [TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device] -> ShowS Source #

Parameterized (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Parameters (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) :: [Type] Source #

Methods

flattenParameters :: TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> HList (Parameters (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)) Source #

replaceParameters :: TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> HList (Parameters (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)) -> TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device Source #

(All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, ffnDim], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device -> IO (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source #

type Rep (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Rep (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) = D1 ('MetaData "TransformerLayer" "Torch.Typed.NN.Transformer" "hasktorch-0.2.0.0-F6yFRaDiRF49lpq95SVuR8" 'False) (C1 ('MetaCons "TransformerLayer" 'PrefixI 'True) ((S1 ('MetaSel ('Just "transformerLayer_mha") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device)) :*: S1 ('MetaSel ('Just "transformerLayer_attnDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 Dropout)) :*: (S1 ('MetaSel ('Just "transformerLayer_ln") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (LayerNorm '[embedDim] dtype device)) :*: S1 ('MetaSel ('Just "transformerLayer_mlp") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (TransformerMLP embedDim ffnDim dtype device)))))
type Parameters (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Parameters (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device) = GParameters (Rep (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device))

transformerLayer Source #

Arguments

:: forall (numHeads :: Nat) (ffnDim :: Nat) (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat) (headDim :: Nat) (seqLen :: Nat) (seqLen' :: Nat) (batchSize :: Nat) dtype device. (1 <= numHeads, embedDim ~ (headDim * numHeads), All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim], IsSuffixOf '[embedDim] '[batchSize, seqLen', embedDim], KnownDType dtype, dtype ~ SumDType dtype, StandardFloatingPointDTypeValidation device dtype, MatMulDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device dtype, SumDTypeIsValid device dtype, KnownDevice device) 
=> TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device

transformer layer model ADT

-> Bool

switch between training mode and evaluation mode (turns random dropout on and off)

-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])

optional attention mask

-> Maybe (Tensor device 'Bool '[batchSize, seqLen])

optional key padding mask

-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])

optional key relations

-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])

optional value relations

-> Tensor device dtype '[batchSize, seqLen', embedDim]

query representation

-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]

key representation

-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]

value representation

-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])

transformer layer output representation

data TransformerLMSpec (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat) (paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLMSpec 

Fields

  :: forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device.
     { lmDropoutSpec :: DropoutSpec
       -- ^ dropout spec
     , lmLayerSpec :: TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device
       -- ^ spec for each and every transformer layer
     }
  -> TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device

Instances

Instances details
Show (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> ShowS Source #

show :: TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> String Source #

showList :: [TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device] -> ShowS Source #

Eq (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

(==) :: TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> Bool Source #

(/=) :: TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> Bool Source #

(paddingIdx <= numEmbeds, 1 <= (numEmbeds - paddingIdx), 1 <= Div embedDim 2, (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds, (Div embedDim 2 * 2) ~ embedDim, All KnownNat '[ffnDim, paddingIdx, numEmbeds, embedDim], HReplicate numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device), Randomizable (HList (HReplicateR numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device))) (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))), KnownDType dtype, RandDTypeIsValid device dtype, StandardFloatingPointDTypeValidation device 'Float, BasicArithmeticDTypeIsValid device 'Float, KnownDevice device) => Randomizable (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> IO (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source #

data TransformerLM (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat) (paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLM 

Fields

Instances

Instances details
Generic (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) :: Type -> Type Source #

Methods

from :: TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) x Source #

to :: Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) x -> TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device Source #

Show (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))) => Show (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> ShowS Source #

show :: TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> String Source #

showList :: [TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device] -> ShowS Source #

(layers ~ HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device), Parameterized (HList layers), HAppendFD (Parameters (HList layers)) '[Parameter device dtype '[numEmbeds, embedDim], Parameter device dtype '[numEmbeds]] (Parameters (HList layers) ++ '[Parameter device dtype '[numEmbeds, embedDim], Parameter device dtype '[numEmbeds]])) => Parameterized (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Parameters (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) :: [Type] Source #

Methods

flattenParameters :: TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> HList (Parameters (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)) Source #

replaceParameters :: TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> HList (Parameters (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)) -> TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device Source #

(All KnownNat '[paddingIdx, embedDim, seqLen, batchSize], (paddingIdx + 1) <= numEmbeds, 1 <= seqLen, HFoldrM IO (FoldLayers batchSize seqLen dtype device) (Tensor device dtype '[batchSize, seqLen, embedDim]) (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device)) (Tensor device dtype '[batchSize, seqLen, embedDim]), BasicArithmeticDTypeIsValid device dtype, ComparisonDTypeIsValid device dtype, ComparisonDTypeIsValid device 'Int64, KnownDType dtype, KnownDevice device) => HasForward (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) (Tensor device 'Int64 '[batchSize, seqLen]) (Tensor device dtype '[batchSize, seqLen, numEmbeds]) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

forward :: TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> Tensor device 'Int64 '[batchSize, seqLen] -> Tensor device dtype '[batchSize, seqLen, numEmbeds] Source #

forwardStoch :: TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> Tensor device 'Int64 '[batchSize, seqLen] -> IO (Tensor device dtype '[batchSize, seqLen, numEmbeds]) Source #

(paddingIdx <= numEmbeds, 1 <= (numEmbeds - paddingIdx), 1 <= Div embedDim 2, (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds, (Div embedDim 2 * 2) ~ embedDim, All KnownNat '[ffnDim, paddingIdx, numEmbeds, embedDim], HReplicate numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device), Randomizable (HList (HReplicateR numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device))) (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))), KnownDType dtype, RandDTypeIsValid device dtype, StandardFloatingPointDTypeValidation device 'Float, BasicArithmeticDTypeIsValid device 'Float, KnownDevice device) => Randomizable (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> IO (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source #

type Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) = D1 ('MetaData "TransformerLM" "Torch.Typed.NN.Transformer" "hasktorch-0.2.0.0-F6yFRaDiRF49lpq95SVuR8" 'False) (C1 ('MetaCons "TransformerLM" 'PrefixI 'True) ((S1 ('MetaSel ('Just "tEmbedding") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (Embedding ('Just paddingIdx) numEmbeds embedDim 'Learned dtype device)) :*: S1 ('MetaSel ('Just "tPosEmbedding") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (Embedding ('Nothing :: Maybe Nat) 2048 embedDim 'Constant dtype device))) :*: (S1 ('MetaSel ('Just "tDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 Dropout) :*: (S1 ('MetaSel ('Just "tLayers") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device)))) :*: S1 ('MetaSel ('Just "tProj") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (Linear embedDim numEmbeds dtype device))))))
type Parameters (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Parameters (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) = GParameters (Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device))

data FoldLayers (batchSize :: Nat) (seqLen :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) Source #

Constructors

FoldLayers 

Fields

Instances

Instances details
(1 <= numHeads, embedDim ~ (headDim * numHeads), All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim], IsSuffixOf '[embedDim] '[batchSize, seqLen, embedDim], KnownDType dtype, StandardFloatingPointDTypeValidation device dtype, MatMulDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device dtype, dtype ~ SumDType dtype, SumDTypeIsValid device dtype, KnownDevice device) => Apply' (FoldLayers batchSize seqLen dtype device) (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device, IO (Tensor device dtype '[batchSize, seqLen, embedDim])) (IO (Tensor device dtype '[batchSize, seqLen, embedDim])) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

apply' :: FoldLayers batchSize seqLen dtype device -> (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device, IO (Tensor device dtype '[batchSize, seqLen, embedDim])) -> IO (Tensor device dtype '[batchSize, seqLen, embedDim]) Source #

transformerLM :: forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen batchSize dtype device. (All KnownNat '[paddingIdx, embedDim, seqLen, batchSize], (paddingIdx + 1) <= numEmbeds, 1 <= seqLen, HFoldrM IO (FoldLayers batchSize seqLen dtype device) (Tensor device dtype '[batchSize, seqLen, embedDim]) (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device)) (Tensor device dtype '[batchSize, seqLen, embedDim]), BasicArithmeticDTypeIsValid device dtype, ComparisonDTypeIsValid device dtype, ComparisonDTypeIsValid device 'Int64, KnownDType dtype, KnownDevice device) => TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device -> Bool -> Tensor device 'Int64 '[batchSize, seqLen] -> IO (Tensor device dtype '[batchSize, seqLen, numEmbeds]) Source #

sinusoidal :: forall numEmbeds embedDim device. (All KnownNat '[numEmbeds, embedDim], 1 <= numEmbeds, 1 <= Div embedDim 2, (Div embedDim 2 * 2) ~ embedDim, StandardFloatingPointDTypeValidation device 'Float, BasicArithmeticDTypeIsValid device 'Float, KnownDevice device) => Tensor device 'Float '[numEmbeds, embedDim] Source #