hasktorch-0.2.0.0: initial implementation for hasktorch based on libtorch

Safe Haskell: None
Language: Haskell2010

Torch.Typed.NN.Transformer

Documentation

data MultiheadAttentionSpec (embedDim :: Nat) (numHeads :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

MultiheadAttentionSpec 

Fields

Instances
Eq (MultiheadAttentionSpec embedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

(==) :: MultiheadAttentionSpec embedDim numHeads dtype device -> MultiheadAttentionSpec embedDim numHeads dtype device -> Bool Source #

(/=) :: MultiheadAttentionSpec embedDim numHeads dtype device -> MultiheadAttentionSpec embedDim numHeads dtype device -> Bool Source #

Show (MultiheadAttentionSpec embedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> MultiheadAttentionSpec embedDim numHeads dtype device -> ShowS Source #

show :: MultiheadAttentionSpec embedDim numHeads dtype device -> String Source #

showList :: [MultiheadAttentionSpec embedDim numHeads dtype device] -> ShowS Source #

(All KnownNat '[embedDim, numHeads], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (MultiheadAttentionSpec embedDim numHeads dtype device) (MultiheadAttention embedDim numHeads dtype device) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: MultiheadAttentionSpec embedDim numHeads dtype device -> IO (MultiheadAttention embedDim numHeads dtype device) Source #

data MultiheadAttention (embedDim :: Nat) (numHeads :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

MultiheadAttention 

Fields

Instances
Show (MultiheadAttention embedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> MultiheadAttention embedDim numHeads dtype device -> ShowS Source #

show :: MultiheadAttention embedDim numHeads dtype device -> String Source #

showList :: [MultiheadAttention embedDim numHeads dtype device] -> ShowS Source #

Generic (MultiheadAttention embedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Rep (MultiheadAttention embedDim numHeads dtype device) :: Type -> Type Source #

Methods

from :: MultiheadAttention embedDim numHeads dtype device -> Rep (MultiheadAttention embedDim numHeads dtype device) x Source #

to :: Rep (MultiheadAttention embedDim numHeads dtype device) x -> MultiheadAttention embedDim numHeads dtype device Source #

(All KnownNat '[embedDim, numHeads], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (MultiheadAttentionSpec embedDim numHeads dtype device) (MultiheadAttention embedDim numHeads dtype device) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: MultiheadAttentionSpec embedDim numHeads dtype device -> IO (MultiheadAttention embedDim numHeads dtype device) Source #

type Rep (MultiheadAttention embedDim numHeads dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Rep (MultiheadAttention embedDim numHeads dtype device) = D1 (MetaData "MultiheadAttention" "Torch.Typed.NN.Transformer" "hasktorch-0.2.0.0-2MIITmBkH8kJT2gDBUcX5D" False) (C1 (MetaCons "MultiheadAttention" PrefixI True) (S1 (MetaSel (Just "mhaInProj") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (Linear embedDim (embedDim * 3) dtype device)) :*: (S1 (MetaSel (Just "mhaOutProj") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (Linear embedDim embedDim dtype device)) :*: S1 (MetaSel (Just "mhaDropout") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 Dropout))))

multiheadAttention :: forall embedDim numHeads seqLen batchSize headDim dtype device. (1 <= numHeads, embedDim ~ (headDim * numHeads), Mod (embedDim * 3) 3 ~ 0, Div (embedDim * 3) 3 ~ embedDim, All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim], KnownDType dtype, dtype ~ DTypePromotion dtype (SumDType dtype), StandardFloatingPointDTypeValidation device dtype, MatMulDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device (SumDType dtype), SumDTypeIsValid device dtype, KnownDevice device) => MultiheadAttention embedDim numHeads dtype device -> Bool -> Tensor device Bool '[seqLen, batchSize] -> Tensor device dtype '[seqLen, batchSize, embedDim] -> IO (Tensor device dtype '[seqLen, batchSize, embedDim], Tensor device dtype '[batchSize, seqLen, seqLen]) Source #
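
A minimal usage sketch (not taken from the library itself): it assumes the Bool argument toggles training-mode dropout, that the '[seqLen, batchSize] Bool tensor is a key-padding mask, and that the MultiheadAttentionSpec constructor takes a DropoutSpec for the attention dropout (its field listing is collapsed above).

  {-# LANGUAGE DataKinds #-}

  import Torch.Typed.NN.Transformer
  -- Tensor, factory (zeros, randn), and NN (sample, DropoutSpec) names are
  -- assumed to come from the corresponding Torch.Typed modules.

  -- embedDim = 64, numHeads = 8 (so headDim = 8), seqLen = 32, batchSize = 4
  attentionDemo :: IO ( Tensor '( 'CPU, 0) 'Float '[32, 4, 64]
                      , Tensor '( 'CPU, 0) 'Float '[4, 32, 32] )
  attentionDemo = do
    mha <- sample (MultiheadAttentionSpec (DropoutSpec 0.1)   -- constructor argument assumed to be the attention DropoutSpec
                     :: MultiheadAttentionSpec 64 8 'Float '( 'CPU, 0))
    let keyPaddingMask = zeros :: Tensor '( 'CPU, 0) 'Bool '[32, 4]   -- mask no positions
    input <- randn :: IO (Tensor '( 'CPU, 0) 'Float '[32, 4, 64])     -- [seqLen, batchSize, embedDim]
    multiheadAttention mha False keyPaddingMask input                 -- False: evaluation mode (assumed)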

data TransformerLMLayerSpec (embedDim :: Nat) (numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLMLayerSpec 

Fields

Instances
Show (TransformerLMLayerSpec embedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerLMLayerSpec embedDim numHeads ffnDim dtype device -> ShowS Source #

show :: TransformerLMLayerSpec embedDim numHeads ffnDim dtype device -> String Source #

showList :: [TransformerLMLayerSpec embedDim numHeads ffnDim dtype device] -> ShowS Source #

(All KnownNat '[embedDim, numHeads, ffnDim], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (TransformerLMLayerSpec embedDim numHeads ffnDim dtype device) (TransformerLMLayer embedDim numHeads ffnDim dtype device) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLMLayerSpec embedDim numHeads ffnDim dtype device -> IO (TransformerLMLayer embedDim numHeads ffnDim dtype device) Source #

data TransformerLMLayer (embedDim :: Nat) (numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLMLayer 

Fields

Instances
(1 <= numHeads, embedDim ~ (headDim * numHeads), Mod (embedDim * 3) 3 ~ 0, Div (embedDim * 3) 3 ~ embedDim, All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim], EndsWith '[seqLen, batchSize, embedDim] '[embedDim], KnownDType dtype, dtype ~ SumDType dtype, StandardFloatingPointDTypeValidation device dtype, MatMulDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device dtype, SumDTypeIsValid device dtype, KnownDevice device) => Apply FoldLayers (TransformerLMLayer embedDim numHeads ffnDim dtype device) ((Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim]) -> IO (Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim])) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

apply :: FoldLayers -> TransformerLMLayer embedDim numHeads ffnDim dtype device -> (Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim]) -> IO (Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim]) Source #

Show (TransformerLMLayer embedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerLMLayer embedDim numHeads ffnDim dtype device -> ShowS Source #

show :: TransformerLMLayer embedDim numHeads ffnDim dtype device -> String Source #

showList :: [TransformerLMLayer embedDim numHeads ffnDim dtype device] -> ShowS Source #

Generic (TransformerLMLayer embedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Rep (TransformerLMLayer embedDim numHeads ffnDim dtype device) :: Type -> Type Source #

Methods

from :: TransformerLMLayer embedDim numHeads ffnDim dtype device -> Rep (TransformerLMLayer embedDim numHeads ffnDim dtype device) x Source #

to :: Rep (TransformerLMLayer embedDim numHeads ffnDim dtype device) x -> TransformerLMLayer embedDim numHeads ffnDim dtype device Source #

(All KnownNat '[embedDim, numHeads, ffnDim], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (TransformerLMLayerSpec embedDim numHeads ffnDim dtype device) (TransformerLMLayer embedDim numHeads ffnDim dtype device) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLMLayerSpec embedDim numHeads ffnDim dtype device -> IO (TransformerLMLayer embedDim numHeads ffnDim dtype device) Source #

type Rep (TransformerLMLayer embedDim numHeads ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Rep (TransformerLMLayer embedDim numHeads ffnDim dtype device) = D1 (MetaData "TransformerLMLayer" "Torch.Typed.NN.Transformer" "hasktorch-0.2.0.0-2MIITmBkH8kJT2gDBUcX5D" False) (C1 (MetaCons "TransformerLMLayer" PrefixI True) ((S1 (MetaSel (Just "mha") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (MultiheadAttention embedDim numHeads dtype device)) :*: S1 (MetaSel (Just "attnDropout") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 Dropout)) :*: (S1 (MetaSel (Just "ln0") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (LayerNorm (embedDim ': ([] :: [Nat])) dtype device)) :*: (S1 (MetaSel (Just "ln1") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (LayerNorm (embedDim ': ([] :: [Nat])) dtype device)) :*: S1 (MetaSel (Just "mlp") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (TransformerLMMLP embedDim ffnDim dtype device))))))

transformerLMLayer :: forall numHeads ffnDim embedDim headDim seqLen batchSize dtype device. (1 <= numHeads, embedDim ~ (headDim * numHeads), Mod (embedDim * 3) 3 ~ 0, Div (embedDim * 3) 3 ~ embedDim, All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim], EndsWith '[seqLen, batchSize, embedDim] '[embedDim], KnownDType dtype, dtype ~ SumDType dtype, StandardFloatingPointDTypeValidation device dtype, MatMulDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device dtype, SumDTypeIsValid device dtype, KnownDevice device) => TransformerLMLayer embedDim numHeads ffnDim dtype device -> Bool -> Tensor device Bool '[seqLen, batchSize] -> Tensor device dtype '[seqLen, batchSize, embedDim] -> IO (Tensor device dtype '[seqLen, batchSize, embedDim]) Source #
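
A rough sketch, not library documentation (extensions and imports as in the multiheadAttention sketch above): it assumes the Bool flag selects training-mode dropout and the Bool tensor is the same key-padding mask as in multiheadAttention. The TransformerLMLayerSpec is left abstract because its fields are collapsed in this listing.

  -- embedDim = 64, numHeads = 8, ffnDim = 256, seqLen = 32, batchSize = 4
  runLayer
    :: TransformerLMLayerSpec 64 8 256 'Float '( 'CPU, 0)   -- spec construction elided
    -> IO (Tensor '( 'CPU, 0) 'Float '[32, 4, 64])
  runLayer layerSpec = do
    layer <- sample layerSpec
    let keyPaddingMask = zeros :: Tensor '( 'CPU, 0) 'Bool '[32, 4]
    input <- randn :: IO (Tensor '( 'CPU, 0) 'Float '[32, 4, 64])
    transformerLMLayer layer False keyPaddingMask input   -- residual block; output shape equals input shape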

data Activation (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

Activation 

Fields

Instances
Show (Activation dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> Activation dtype device -> ShowS Source #

show :: Activation dtype device -> String Source #

showList :: [Activation dtype device] -> ShowS Source #

Parameterized (Activation dtype device) ('[] :: [Type]) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

flattenParameters :: Activation dtype device -> HList '[] Source #

replaceParameters :: Activation dtype device -> HList '[] -> Activation dtype device Source #
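
Going by the Rep type of TransformerLMMLP below, Activation fills the two activation slots of the MLP, and the empty Parameterized instance above shows it carries no learnable parameters. A heavily hedged sketch, assuming the constructor wraps a shape-polymorphic tensor function and that relu from the typed functional API is in scope:

  reluActivation :: Activation 'Float '( 'CPU, 0)
  reluActivation = Activation relu   -- field assumed to be: forall shape. Tensor device dtype shape -> Tensor device dtype shape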

data TransformerLMMLPSpec (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLMMLPSpec 

Fields

Instances
Show (TransformerLMMLPSpec embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerLMMLPSpec embedDim ffnDim dtype device -> ShowS Source #

show :: TransformerLMMLPSpec embedDim ffnDim dtype device -> String Source #

showList :: [TransformerLMMLPSpec embedDim ffnDim dtype device] -> ShowS Source #

(All KnownNat '[embedDim, ffnDim], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (TransformerLMMLPSpec embedDim ffnDim dtype device) (TransformerLMMLP embedDim ffnDim dtype device) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLMMLPSpec embedDim ffnDim dtype device -> IO (TransformerLMMLP embedDim ffnDim dtype device) Source #

data TransformerLMMLP (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLMMLP 

Fields

Instances
Show (TransformerLMMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerLMMLP embedDim ffnDim dtype device -> ShowS Source #

show :: TransformerLMMLP embedDim ffnDim dtype device -> String Source #

showList :: [TransformerLMMLP embedDim ffnDim dtype device] -> ShowS Source #

Generic (TransformerLMMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Rep (TransformerLMMLP embedDim ffnDim dtype device) :: Type -> Type Source #

Methods

from :: TransformerLMMLP embedDim ffnDim dtype device -> Rep (TransformerLMMLP embedDim ffnDim dtype device) x Source #

to :: Rep (TransformerLMMLP embedDim ffnDim dtype device) x -> TransformerLMMLP embedDim ffnDim dtype device Source #

(All KnownNat '[embedDim, ffnDim], KnownDType dtype, KnownDevice device, RandDTypeIsValid device dtype) => Randomizable (TransformerLMMLPSpec embedDim ffnDim dtype device) (TransformerLMMLP embedDim ffnDim dtype device) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLMMLPSpec embedDim ffnDim dtype device -> IO (TransformerLMMLP embedDim ffnDim dtype device) Source #

type Rep (TransformerLMMLP embedDim ffnDim dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Rep (TransformerLMMLP embedDim ffnDim dtype device) = D1 (MetaData "TransformerLMMLP" "Torch.Typed.NN.Transformer" "hasktorch-0.2.0.0-2MIITmBkH8kJT2gDBUcX5D" False) (C1 (MetaCons "TransformerLMMLP" PrefixI True) ((S1 (MetaSel (Just "linear0") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (Linear embedDim ffnDim dtype device)) :*: (S1 (MetaSel (Just "linear1") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (Linear ffnDim embedDim dtype device)) :*: S1 (MetaSel (Just "dropout0") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 Dropout))) :*: (S1 (MetaSel (Just "dropout1") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 Dropout) :*: (S1 (MetaSel (Just "activation0") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (Activation dtype device)) :*: S1 (MetaSel (Just "activation1") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (Activation dtype device))))))

transformerLMMLP :: forall embedDim ffnDim seqLen batchSize dtype device. TransformerLMMLP embedDim ffnDim dtype device -> Bool -> Tensor device dtype '[seqLen, batchSize, embedDim] -> IO (Tensor device dtype '[seqLen, batchSize, embedDim]) Source #
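
Per the Rep type above, the MLP consists of two Linear layers (embedDim -> ffnDim -> embedDim) plus dropout and activation wrappers, so its forward pass preserves the input shape. A minimal sketch (spec construction elided because its fields are collapsed in this listing; extensions and imports as in the earlier sketches):

  runMLP
    :: TransformerLMMLPSpec 64 256 'Float '( 'CPU, 0)
    -> IO (Tensor '( 'CPU, 0) 'Float '[32, 4, 64])
  runMLP mlpSpec = do
    mlp <- sample mlpSpec
    input <- randn :: IO (Tensor '( 'CPU, 0) 'Float '[32, 4, 64])
    transformerLMMLP mlp False input   -- output shape matches the input: [32, 4, 64]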

data FoldLayers Source #

Constructors

FoldLayers 
Instances
(1 <= numHeads, embedDim ~ (headDim * numHeads), Mod (embedDim * 3) 3 ~ 0, Div (embedDim * 3) 3 ~ embedDim, All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim], EndsWith '[seqLen, batchSize, embedDim] '[embedDim], KnownDType dtype, dtype ~ SumDType dtype, StandardFloatingPointDTypeValidation device dtype, MatMulDTypeIsValid device dtype, BasicArithmeticDTypeIsValid device dtype, SumDTypeIsValid device dtype, KnownDevice device) => Apply FoldLayers (TransformerLMLayer embedDim numHeads ffnDim dtype device) ((Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim]) -> IO (Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim])) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

apply :: FoldLayers -> TransformerLMLayer embedDim numHeads ffnDim dtype device -> (Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim]) -> IO (Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim]) Source #

getHidden :: forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen batchSize dtype device. (All KnownNat '[paddingIdx, embedDim, seqLen, batchSize], (paddingIdx + 1) <= numEmbeds, 1 <= seqLen, HFoldrM IO FoldLayers (Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim]) (HReplicateR numAttnLayers (TransformerLMLayer embedDim numHeads ffnDim dtype device)), BasicArithmeticDTypeIsValid device dtype, ComparisonDTypeIsValid device dtype, ComparisonDTypeIsValid device Int64, KnownDevice device) => Embedding (Just paddingIdx) numEmbeds embedDim Learned dtype device -> Embedding Nothing 2048 embedDim Constant dtype device -> Dropout -> Bool -> HList (HReplicateR numAttnLayers (TransformerLMLayer embedDim numHeads ffnDim dtype device)) -> Tensor device Int64 '[batchSize, seqLen] -> IO (Tensor device dtype '[seqLen, batchSize, embedDim]) Source #
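
Reading the signature, getHidden combines the learned token embedding with the fixed 2048-position embedding, applies dropout, and then folds the HList of TransformerLMLayers over the (key-padding mask, hidden states) pair via HFoldrM and FoldLayers; the input ids have shape '[batchSize, seqLen] while the hidden states come back as '[seqLen, batchSize, embedDim]. A conceptual sketch of the per-layer step that the Apply FoldLayers instance above performs (this is not the library's code; how FoldLayers carries the Bool train flag is not visible in this listing):

  stepLayer
    :: Bool                                                  -- train flag
    -> TransformerLMLayer 64 8 256 'Float '( 'CPU, 0)
    -> ( Tensor '( 'CPU, 0) 'Bool  '[32, 4]                  -- key-padding mask
       , Tensor '( 'CPU, 0) 'Float '[32, 4, 64] )            -- hidden states
    -> IO ( Tensor '( 'CPU, 0) 'Bool  '[32, 4]
          , Tensor '( 'CPU, 0) 'Float '[32, 4, 64] )
  stepLayer train layer (mask, hidden) = do
    hidden' <- transformerLMLayer layer train mask hidden
    pure (mask, hidden')   -- the mask is threaded through unchanged, matching the Apply instance's type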

data TransformerLMSpec (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat) (paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat) (seqLen :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLMSpec 

Fields

Instances
Show (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

showsPrec :: Int -> TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device -> ShowS Source #

show :: TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device -> String Source #

showList :: [TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device] -> ShowS Source #

(paddingIdx <= numEmbeds, 1 <= (numEmbeds - paddingIdx), (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds, (Div embedDim 2 * 2) ~ embedDim, All KnownNat '[ffnDim, paddingIdx, numEmbeds, embedDim, seqLen], HReplicate numAttnLayers (TransformerLMLayerSpec embedDim numHeads ffnDim dtype device), Randomizable (HList (HReplicateR numAttnLayers (TransformerLMLayerSpec embedDim numHeads ffnDim dtype device))) (HList (HReplicateR numAttnLayers (TransformerLMLayer embedDim numHeads ffnDim dtype device))), KnownDType dtype, RandDTypeIsValid device dtype, StandardFloatingPointDTypeValidation device Float, BasicArithmeticDTypeIsValid device Float, KnownDevice device) => Randomizable (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device -> IO (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) Source #
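
All hyperparameters of the language model are type-level arguments, so a concrete model is obtained by fixing them and calling sample. A minimal sketch with made-up sizes (the spec's record fields are collapsed in this listing, so its construction is left to the caller; extensions and imports as in the earlier sketches):

  type SmallLMSpec = TransformerLMSpec 2 8 256 0 1000 64 128 'Float '( 'CPU, 0)
  type SmallLM     = TransformerLM     2 8 256 0 1000 64 128 'Float '( 'CPU, 0)
  -- 2 attention layers, 8 heads, ffnDim 256, paddingIdx 0, 1000 embeddings,
  -- embedDim 64, seqLen 128, Float tensors on the CPU

  initSmallLM :: SmallLMSpec -> IO SmallLM
  initSmallLM = sample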

data TransformerLM (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat) (paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat) (seqLen :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)) where Source #

Constructors

TransformerLM 

Fields

Instances
Generic (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

Associated Types

type Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) :: Type -> Type Source #

Methods

from :: TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device -> Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) x Source #

to :: Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) x -> TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device Source #

(paddingIdx <= numEmbeds, 1 <= (numEmbeds - paddingIdx), (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds, (Div embedDim 2 * 2) ~ embedDim, All KnownNat '[ffnDim, paddingIdx, numEmbeds, embedDim, seqLen], HReplicate numAttnLayers (TransformerLMLayerSpec embedDim numHeads ffnDim dtype device), Randomizable (HList (HReplicateR numAttnLayers (TransformerLMLayerSpec embedDim numHeads ffnDim dtype device))) (HList (HReplicateR numAttnLayers (TransformerLMLayer embedDim numHeads ffnDim dtype device))), KnownDType dtype, RandDTypeIsValid device dtype, StandardFloatingPointDTypeValidation device Float, BasicArithmeticDTypeIsValid device Float, KnownDevice device) => Randomizable (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) Source #
Instance details

Defined in Torch.Typed.NN.Transformer

Methods

sample :: TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device -> IO (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) Source #

type Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) Source # 
Instance details

Defined in Torch.Typed.NN.Transformer

type Rep (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device) = D1 (MetaData "TransformerLM" "Torch.Typed.NN.Transformer" "hasktorch-0.2.0.0-2MIITmBkH8kJT2gDBUcX5D" False) (C1 (MetaCons "TransformerLM" PrefixI True) ((S1 (MetaSel (Just "tEmbedding") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (Embedding (Just paddingIdx) numEmbeds embedDim Learned dtype device)) :*: S1 (MetaSel (Just "tPosEmbedding") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (Embedding (Nothing :: Maybe Nat) 2048 embedDim Constant dtype device))) :*: (S1 (MetaSel (Just "tDropout") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 Dropout) :*: (S1 (MetaSel (Just "tLayers") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (HList (HReplicateR numAttnLayers (TransformerLMLayer embedDim numHeads ffnDim dtype device)))) :*: S1 (MetaSel (Just "tProj") NoSourceUnpackedness NoSourceStrictness DecidedLazy) (Rec0 (Linear embedDim seqLen dtype device))))))

sinusoidal :: forall numEmbeds embedDim device. (All KnownNat '[numEmbeds, embedDim], 1 <= numEmbeds, 1 <= Div embedDim 2, (Div embedDim 2 * 2) ~ embedDim, StandardFloatingPointDTypeValidation device Float, BasicArithmeticDTypeIsValid device Float, KnownDevice device) => Tensor device Float '[numEmbeds, embedDim] Source #
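
sinusoidal builds the constant table of sinusoidal position embeddings; this is presumably what backs the Embedding Nothing 2048 embedDim Constant argument of getHidden. Because every type variable occurs in the result type, a type annotation is enough to pick the sizes; a minimal sketch (extensions and imports as in the earlier sketches):

  posTable :: Tensor '( 'CPU, 0) 'Float '[2048, 64]
  posTable = sinusoidal   -- 2048 positions, embedDim 64 (must be even), on the CPU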

logits :: forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen batchSize dtype device. (All KnownNat '[paddingIdx, embedDim, seqLen, batchSize], (paddingIdx + 1) <= numEmbeds, 1 <= seqLen, HFoldrM IO FoldLayers (Tensor device Bool '[seqLen, batchSize], Tensor device dtype '[seqLen, batchSize, embedDim]) (HReplicateR numAttnLayers (TransformerLMLayer embedDim numHeads ffnDim dtype device)), BasicArithmeticDTypeIsValid device dtype, ComparisonDTypeIsValid device dtype, ComparisonDTypeIsValid device Int64, KnownDevice device) => TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim seqLen dtype device -> Bool -> Tensor device Int64 '[batchSize, seqLen] -> IO (Tensor device dtype '[batchSize, seqLen, seqLen]) Source #
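
Note that, per the Rep type of TransformerLM above, the final projection tProj is Linear embedDim seqLen, which is why logits returns a '[batchSize, seqLen, seqLen] tensor rather than one over the embedding vocabulary. A minimal forward-pass sketch with the same sizes as the SmallLM sketch above (construction of the model and of the Int64 token ids is elided; extensions and imports as in the earlier sketches):

  forwardSmallLM
    :: TransformerLM 2 8 256 0 1000 64 128 'Float '( 'CPU, 0)
    -> Tensor '( 'CPU, 0) 'Int64 '[16, 128]                   -- [batchSize, seqLen] token ids
    -> IO (Tensor '( 'CPU, 0) 'Float '[16, 128, 128])         -- [batchSize, seqLen, seqLen]
  forwardSmallLM model input = logits model False input       -- False: evaluation mode (assumed)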