hasktorch-gradually-typed-0.2.0.0: experimental project for hasktorch
Safe Haskell: Safe-Inferred
Language: Haskell2010

Torch.GraduallyTyped.NN.Transformer.Type

Synopsis

Documentation

data TransformerStyle Source #

A data type representing the style of a transformer. Every supported transformer has a constructor of this type.

Instances

Instances details
Show TransformerStyle Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Eq TransformerStyle Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

SingKind TransformerStyle Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Associated Types

type Demote TransformerStyle = (r :: Type) Source #

SingI 'BART Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

sing :: Sing 'BART Source #

SingI 'BERT Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

sing :: Sing 'BERT Source #

SingI 'ByT5 Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

sing :: Sing 'ByT5 Source #

SingI 'GPT2 Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

sing :: Sing 'GPT2 Source #

SingI 'MBART Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

sing :: Sing 'MBART Source #

SingI 'Pegasus Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

sing :: Sing 'Pegasus Source #

SingI 'RoBERTa Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

sing :: Sing 'RoBERTa Source #

SingI 'T5 Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

sing :: Sing 'T5 Source #

type Demote TransformerStyle Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type Sing Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type family T5Sym0 :: TransformerStyle where ... Source #

Equations

T5Sym0 = 'T5 

type family ByT5Sym0 :: TransformerStyle where ... Source #

Equations

ByT5Sym0 = 'ByT5 

type family BARTSym0 :: TransformerStyle where ... Source #

Equations

BARTSym0 = 'BART 

type family MBARTSym0 :: TransformerStyle where ... Source #

Equations

MBARTSym0 = 'MBART 

type family PegasusSym0 :: TransformerStyle where ... Source #

Equations

PegasusSym0 = 'Pegasus 

type family BERTSym0 :: TransformerStyle where ... Source #

Equations

BERTSym0 = 'BERT 

type family RoBERTaSym0 :: TransformerStyle where ... Source #

Equations

RoBERTaSym0 = 'RoBERTa 

type family GPT2Sym0 :: TransformerStyle where ... Source #

Equations

GPT2Sym0 = 'GPT2 

data TransformerHead Source #

A data type representing the type of head used in a transformer.

Constructors

WithoutHead 
WithLMHead 

type family WithoutHeadSym0 :: TransformerHead where ... Source #

type family WithLMHeadSym0 :: TransformerHead where ... Source #

padded :: Integral n => n -> a -> [a] -> [a] Source #

mkTransformerInput Source #

Arguments

:: forall batchDim seqDim device m output. (MonadThrow m, SGetDim batchDim, SGetDim seqDim, Catch ('Shape '['Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") 'UncheckedSize] <+> 'Shape '[batchDim, seqDim]), output ~ Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '[batchDim, seqDim])) 
=> Int

padding token id

-> SDim batchDim

batch dimension singleton

-> SDim seqDim

sequence dimension singleton

-> SDevice device

device for the tensor

-> [[Int]]

batch of input ids

-> m output

input tensor

Converts a doubly-nested list of input ids to a batched input tensor. The outer list is over batches, the inner list over sequences. The batch size is inferred from the length of the outer list, and the sequence length from the length of the inner lists. Each sequence of input ids is padded, and the resulting tensor truncated, to the maximum sequence length.

type MkPosC device shape seqDim seqName seqSize output = (SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), seqDim ~ 'Dim seqName seqSize, output ~ Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") seqSize])) Source #

mkPos Source #

Arguments

:: forall m gradient layout device dataType shape seqDim seqName seqSize output. (MonadThrow m, MkPosC device shape seqDim seqName seqSize output) 
=> Tensor gradient layout device dataType shape

input tensor

-> m output

positions of the input tokens

Computes absolute positions of the input tokens. Given an input tensor of shape [batchDim, Dim seqName seqSize], returns a tensor of shape [Dim "*" seqSize].

data MkAbsPos Source #

Constructors

MkAbsPos 
MkAbsPosWithOffset 

Fields

Instances

Instances details
Generic MkAbsPos Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Associated Types

type Rep MkAbsPos :: Type -> Type Source #

Show MkAbsPos Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Eq MkAbsPos Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Ord MkAbsPos Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasStateDict MkAbsPos Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasInitialize MkAbsPos generatorDevice MkAbsPos generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

initialize :: MonadThrow m => ModelSpec MkAbsPos -> Generator generatorDevice -> m (MkAbsPos, Generator generatorDevice) Source #

MkPosC device shape seqDim seqName seqSize output => HasForward MkAbsPos (Tensor gradient layout device dataType shape) generatorDevice (Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") seqSize])) generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

forward :: MonadThrow m => MkAbsPos -> Tensor gradient layout device dataType shape -> Generator generatorDevice -> m (Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") seqSize]), Generator generatorDevice) Source #

type Rep MkAbsPos Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type Rep MkAbsPos = D1 ('MetaData "MkAbsPos" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "MkAbsPos" 'PrefixI 'False) (U1 :: Type -> Type) :+: C1 ('MetaCons "MkAbsPosWithOffset" 'PrefixI 'True) (S1 ('MetaSel ('Just "absPosOffset") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Int)))
type ModelSpec MkAbsPos Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

mkRelPos' :: Int -> Int -> Int -> Int -> [[Int]] Source #

Computes relative positions of the input tokens to the encoder.

>>> mkRelPos' 32 128 21 17
[[0,17,18,19,20,21,22,23,24,24,24,24,25,25,25,25,26],[1,0,17,18,19,20,21,22,23,24,24,24,24,25,25,25,25],[2,1,0,17,18,19,20,21,22,23,24,24,24,24,25,25,25],[3,2,1,0,17,18,19,20,21,22,23,24,24,24,24,25,25],[4,3,2,1,0,17,18,19,20,21,22,23,24,24,24,24,25],[5,4,3,2,1,0,17,18,19,20,21,22,23,24,24,24,24],[6,5,4,3,2,1,0,17,18,19,20,21,22,23,24,24,24],[7,6,5,4,3,2,1,0,17,18,19,20,21,22,23,24,24],[8,7,6,5,4,3,2,1,0,17,18,19,20,21,22,23,24],[8,8,7,6,5,4,3,2,1,0,17,18,19,20,21,22,23],[8,8,8,7,6,5,4,3,2,1,0,17,18,19,20,21,22],[8,8,8,8,7,6,5,4,3,2,1,0,17,18,19,20,21],[9,8,8,8,8,7,6,5,4,3,2,1,0,17,18,19,20],[9,9,8,8,8,8,7,6,5,4,3,2,1,0,17,18,19],[9,9,9,8,8,8,8,7,6,5,4,3,2,1,0,17,18],[9,9,9,9,8,8,8,8,7,6,5,4,3,2,1,0,17],[10,9,9,9,9,8,8,8,8,7,6,5,4,3,2,1,0],[10,10,9,9,9,9,8,8,8,8,7,6,5,4,3,2,1],[10,10,10,9,9,9,9,8,8,8,8,7,6,5,4,3,2],[10,10,10,10,9,9,9,9,8,8,8,8,7,6,5,4,3],[10,10,10,10,10,9,9,9,9,8,8,8,8,7,6,5,4]]

type MkRelPosC device shape seqDim seqName seqSize output = (SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), seqDim ~ 'Dim seqName seqSize, Catch ('['Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") 'UncheckedSize] <+> '['Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize]), output ~ Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize])) Source #

mkRelPos Source #

Arguments

:: forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output. (MonadThrow m, MkRelPosC device shape seqDim seqName seqSize output) 
=> SDim relPosEncBucketDim

bucket dimension

-> Int

maximum distance

-> Tensor gradient layout device dataType shape

input tensor

-> m output

relative positions of the input tokens

Computes relative positions of the input tokens to the encoder. Given an input tensor of shape [batchDim, Dim seqName seqSize], returns a tensor of shape [1, Dim "*" seqSize, Dim "*" seqSize].

mkDecoderRelPos' :: Int -> Int -> Int -> Int -> [[Int]] Source #

Computes relative positions of the input tokens to the decoder.

>>> mkDecoderRelPos' 32 128 21 17
[[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[3,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[4,3,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0],[5,4,3,2,1,0,0,0,0,0,0,0,0,0,0,0,0],[6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0,0],[7,6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0],[8,7,6,5,4,3,2,1,0,0,0,0,0,0,0,0,0],[9,8,7,6,5,4,3,2,1,0,0,0,0,0,0,0,0],[10,9,8,7,6,5,4,3,2,1,0,0,0,0,0,0,0],[11,10,9,8,7,6,5,4,3,2,1,0,0,0,0,0,0],[12,11,10,9,8,7,6,5,4,3,2,1,0,0,0,0,0],[13,12,11,10,9,8,7,6,5,4,3,2,1,0,0,0,0],[14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,0,0],[15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,0],[16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0],[16,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1],[16,16,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2],[17,16,16,16,15,14,13,12,11,10,9,8,7,6,5,4,3],[17,17,16,16,16,15,14,13,12,11,10,9,8,7,6,5,4]]

mkDecoderRelPos Source #

Arguments

:: forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output. (MonadThrow m, MkRelPosC device shape seqDim seqName seqSize output) 
=> SDim relPosEncBucketDim

bucket dimension

-> Int

maximum distance

-> Tensor gradient layout device dataType shape

decoder input tensor

-> m output

relative positions of the input tokens

Computes relative positions of the input tokens to the decoder. Given an input tensor of shape [batchDim, Dim seqName seqSize], returns a tensor of shape [1, Dim "*" seqSize, Dim "*" seqSize].

data MkRelPos (relPosEncBucketDim :: Dim (Name Symbol) (Size Nat)) where Source #

Constructors

MkRelPos 

Fields

MkDecoderRelPos 

Fields

Instances

Instances details
Generic (MkRelPos relPosEncBucketDim) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Associated Types

type Rep (MkRelPos relPosEncBucketDim) :: Type -> Type Source #

Methods

from :: MkRelPos relPosEncBucketDim -> Rep (MkRelPos relPosEncBucketDim) x Source #

to :: Rep (MkRelPos relPosEncBucketDim) x -> MkRelPos relPosEncBucketDim Source #

Show (MkRelPos relPosEncBucketDim) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

showsPrec :: Int -> MkRelPos relPosEncBucketDim -> ShowS Source #

show :: MkRelPos relPosEncBucketDim -> String Source #

showList :: [MkRelPos relPosEncBucketDim] -> ShowS Source #

HasStateDict (MkRelPos relPosEncBucketDim) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (MkRelPos relPosEncBucketDim) -> StateDictKey -> m (MkRelPos relPosEncBucketDim) Source #

toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> MkRelPos relPosEncBucketDim -> m () Source #

HasInitialize (MkRelPos relPosEncBucketDim) generatorDevice (MkRelPos relPosEncBucketDim) generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

initialize :: MonadThrow m => ModelSpec (MkRelPos relPosEncBucketDim) -> Generator generatorDevice -> m (MkRelPos relPosEncBucketDim, Generator generatorDevice) Source #

MkRelPosC device shape seqDim seqName seqSize output => HasForward (MkRelPos relPosEncBucketDim) (Tensor gradient layout device dataType shape) generatorDevice (Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize])) generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

forward :: MonadThrow m => MkRelPos relPosEncBucketDim -> Tensor gradient layout device dataType shape -> Generator generatorDevice -> m (Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize]), Generator generatorDevice) Source #

type Rep (MkRelPos relPosEncBucketDim) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type Rep (MkRelPos relPosEncBucketDim) = D1 ('MetaData "MkRelPos" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "MkRelPos" 'PrefixI 'True) (S1 ('MetaSel ('Just "relPosEncBucketDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim relPosEncBucketDim)) :*: S1 ('MetaSel ('Just "relPosMaxDistance") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Int)) :+: C1 ('MetaCons "MkDecoderRelPos" 'PrefixI 'True) (S1 ('MetaSel ('Just "decoderRelPosEncBucketDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim relPosEncBucketDim)) :*: S1 ('MetaSel ('Just "decoderRelPosMaxDistance") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Int)))
type ModelSpec (MkRelPos relPosEncBucketDim) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type ModelSpec (MkRelPos relPosEncBucketDim) = MkRelPos relPosEncBucketDim

type MkTransformerPaddingMaskC layout device dataType shape output = (SGetDevice device, Catch (dataType <+> 'DataType 'Int64), Catch (BroadcastShapesF shape ('Shape '[])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device ('DataType 'Bool) (BroadcastShapesF shape ('Shape '[]))) Source #

mkTransformerPaddingMask Source #

Arguments

:: forall m gradient layout device dataType shape output. (MonadThrow m, MkTransformerPaddingMaskC layout device dataType shape output) 
=> Int

padding token id

-> Tensor gradient layout device dataType shape

input tensor

-> m output

padding mask

Computes the padding mask for a transformer. Given an input tensor of shape [batchDim, Dim seqName seqSize], returns a tensor of shape [batchDim, Dim "*" seqSize].

newtype MkTransformerPaddingMask Source #

Constructors

MkTransformerPaddingMask 

Fields

Instances

Instances details
Generic MkTransformerPaddingMask Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Associated Types

type Rep MkTransformerPaddingMask :: Type -> Type Source #

Show MkTransformerPaddingMask Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasStateDict MkTransformerPaddingMask Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasInitialize MkTransformerPaddingMask generatorDevice MkTransformerPaddingMask generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

MkTransformerPaddingMaskC layout device dataType shape output => HasForward MkTransformerPaddingMask (Tensor gradient layout device dataType shape) generatorDevice output generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

forward :: MonadThrow m => MkTransformerPaddingMask -> Tensor gradient layout device dataType shape -> Generator generatorDevice -> m (output, Generator generatorDevice) Source #

type Rep MkTransformerPaddingMask Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type Rep MkTransformerPaddingMask = D1 ('MetaData "MkTransformerPaddingMask" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'True) (C1 ('MetaCons "MkTransformerPaddingMask" 'PrefixI 'True) (S1 ('MetaSel ('Just "padTokenId") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Int)))
type ModelSpec MkTransformerPaddingMask Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output = (SGetLayout layout, SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), Catch (gradient <+> 'Gradient 'WithoutGradient), Catch (dataType <+> 'DataType 'Bool), Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape), Catch (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device transformerDataType (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]))) Source #

mkTransformerAttentionMask Source #

Arguments

:: forall m transformerDataType gradient layout device dataType shape seqDim output. (MonadThrow m, MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output) 
=> SDataType transformerDataType

data type singleton of the transformer

-> Double

attention mask bias (typically a large negative number)

-> Tensor gradient layout device dataType shape

encoder padding mask

-> m output 

Creates a bidirectional attention mask for a transformer. Given a padding mask of shape [batchDim, seqDim], returns a tensor of shape [batchDim, seqDim, seqDim].

data MkTransformerAttentionMask (dataType :: DataType DType) where Source #

Constructors

MkTransformerAttentionMask 

Fields

Instances

Instances details
Generic (MkTransformerAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Associated Types

type Rep (MkTransformerAttentionMask dataType) :: Type -> Type Source #

Show (MkTransformerAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasStateDict (MkTransformerAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasInitialize (MkTransformerAttentionMask dataType) generatorDevice (MkTransformerAttentionMask dataType) generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

initialize :: MonadThrow m => ModelSpec (MkTransformerAttentionMask dataType) -> Generator generatorDevice -> m (MkTransformerAttentionMask dataType, Generator generatorDevice) Source #

MkTransformerAttentionMaskC dataType inputGradient inputLayout inputDevice inputDataType inputShape seqDim output => HasForward (MkTransformerAttentionMask dataType) (Tensor inputGradient inputLayout inputDevice inputDataType inputShape) generatorDevice output generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

forward :: MonadThrow m => MkTransformerAttentionMask dataType -> Tensor inputGradient inputLayout inputDevice inputDataType inputShape -> Generator generatorDevice -> m (output, Generator generatorDevice) Source #

type Rep (MkTransformerAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type Rep (MkTransformerAttentionMask dataType) = D1 ('MetaData "MkTransformerAttentionMask" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "MkTransformerAttentionMask" 'PrefixI 'True) (S1 ('MetaSel ('Just "attentionMaskDataType") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDataType dataType)) :*: S1 ('MetaSel ('Just "attentionMaskBias") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Double)))
type ModelSpec (MkTransformerAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output = (SGetLayout layout, SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), Catch seqDim, Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape), Catch (BroadcastShapesF ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]) (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)), Catch (BroadcastShapesF (BroadcastShapesF ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]) (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device transformerDataType (BroadcastShapesF (BroadcastShapesF ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]) (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]))) Source #

mkTransformerDecoderAttentionMask Source #

Arguments

:: forall m transformerDataType gradient layout device dataType shape seqDim output. (MonadThrow m, MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output) 
=> SDataType transformerDataType

data type singleton of the transformer

-> Double

attention mask bias (typically a large negative number)

-> Tensor gradient layout device dataType shape

decoder padding mask

-> m output 

Creates a causal attention mask for a transformer decoder. Given a padding mask of shape [batchDim, seqDim], returns a tensor of shape [batchDim, seqDim, seqDim].

data MkTransformerDecoderAttentionMask (dataType :: DataType DType) where Source #

Constructors

MkTransformerDecoderAttentionMask 

Fields

Instances

Instances details
Generic (MkTransformerDecoderAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Associated Types

type Rep (MkTransformerDecoderAttentionMask dataType) :: Type -> Type Source #

Show (MkTransformerDecoderAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasStateDict (MkTransformerDecoderAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasInitialize (MkTransformerDecoderAttentionMask dataType) generatorDevice (MkTransformerDecoderAttentionMask dataType) generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

initialize :: MonadThrow m => ModelSpec (MkTransformerDecoderAttentionMask dataType) -> Generator generatorDevice -> m (MkTransformerDecoderAttentionMask dataType, Generator generatorDevice) Source #

MkTransformerDecoderAttentionMaskC dataType decoderInputLayout decoderInputDevice decoderInputShape seqDim output => HasForward (MkTransformerDecoderAttentionMask dataType) (Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape) generatorDevice output generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

forward :: MonadThrow m => MkTransformerDecoderAttentionMask dataType -> Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape -> Generator generatorDevice -> m (output, Generator generatorDevice) Source #

type Rep (MkTransformerDecoderAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type Rep (MkTransformerDecoderAttentionMask dataType) = D1 ('MetaData "MkTransformerDecoderAttentionMask" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "MkTransformerDecoderAttentionMask" 'PrefixI 'True) (S1 ('MetaSel ('Just "decoderAttentionMaskDataType") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDataType dataType)) :*: S1 ('MetaSel ('Just "decoderAttentionMaskBias") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Double)))
type ModelSpec (MkTransformerDecoderAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output = (SGetLayout layout, SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), SGetShape decoderInputShape, decoderInputSeqDim ~ (decoderInputShape ! 1), Catch (gradient <+> 'Gradient 'WithoutGradient), Catch (dataType <+> 'DataType 'Bool), Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape), Catch (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device transformerDataType (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim]))) Source #

mkTransformerCrossAttentionMask Source #

Arguments

:: forall m transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output. (MonadThrow m, MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output) 
=> SDataType transformerDataType

data type singleton of the transformer

-> SShape decoderInputShape

decoder input shape

-> Double

attention mask bias (typically a large negative number)

-> Tensor gradient layout device dataType shape

encoder padding mask

-> m output 

Creates a cross-attention mask for an encoder-decoder transformer. Given an encoder padding mask of shape [batchDim, seqDim], and the shape [batchDim, decoderSeqDim] of the decoder's input, returns a tensor of shape [batchDim, decoderSeqDim, seqDim].

data MkTransformerCrossAttentionMask (dataType :: DataType DType) where Source #

Constructors

MkTransformerCrossAttentionMask 

Fields

Instances

Instances details
Generic (MkTransformerCrossAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Associated Types

type Rep (MkTransformerCrossAttentionMask dataType) :: Type -> Type Source #

Show (MkTransformerCrossAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasStateDict (MkTransformerCrossAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

HasInitialize (MkTransformerCrossAttentionMask dataType) generatorDevice (MkTransformerCrossAttentionMask dataType) generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

initialize :: MonadThrow m => ModelSpec (MkTransformerCrossAttentionMask dataType) -> Generator generatorDevice -> m (MkTransformerCrossAttentionMask dataType, Generator generatorDevice) Source #

MkTransformerCrossAttentionMaskC dataType decoderInputShape decoderInputSeqDim inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaksDataType inputPaddingMaskShape seqDim output => HasForward (MkTransformerCrossAttentionMask dataType) (Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape, Tensor inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaksDataType inputPaddingMaskShape) generatorDevice output generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

forward :: MonadThrow m => MkTransformerCrossAttentionMask dataType -> (Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape, Tensor inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaksDataType inputPaddingMaskShape) -> Generator generatorDevice -> m (output, Generator generatorDevice) Source #

type Rep (MkTransformerCrossAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type Rep (MkTransformerCrossAttentionMask dataType) = D1 ('MetaData "MkTransformerCrossAttentionMask" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "MkTransformerCrossAttentionMask" 'PrefixI 'True) (S1 ('MetaSel ('Just "crossAttentionMaskDataType") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDataType dataType)) :*: S1 ('MetaSel ('Just "crossAttentionMaskBias") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Double)))
type ModelSpec (MkTransformerCrossAttentionMask dataType) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

data ShiftRight fillValue where Source #

Constructors

ShiftRight 

Fields

  • :: forall fillValue. fillValue

    fill value for shift right

  • -> ShiftRight fillValue
     

Instances

Instances details
Generic (ShiftRight fillValue) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Associated Types

type Rep (ShiftRight fillValue) :: Type -> Type Source #

Methods

from :: ShiftRight fillValue -> Rep (ShiftRight fillValue) x Source #

to :: Rep (ShiftRight fillValue) x -> ShiftRight fillValue Source #

Show fillValue => Show (ShiftRight fillValue) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

showsPrec :: Int -> ShiftRight fillValue -> ShowS Source #

show :: ShiftRight fillValue -> String Source #

showList :: [ShiftRight fillValue] -> ShowS Source #

Eq fillValue => Eq (ShiftRight fillValue) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

(==) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #

(/=) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #

Ord fillValue => Ord (ShiftRight fillValue) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

compare :: ShiftRight fillValue -> ShiftRight fillValue -> Ordering Source #

(<) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #

(<=) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #

(>) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #

(>=) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #

max :: ShiftRight fillValue -> ShiftRight fillValue -> ShiftRight fillValue Source #

min :: ShiftRight fillValue -> ShiftRight fillValue -> ShiftRight fillValue Source #

HasStateDict (ShiftRight fillValue) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

(input ~ Tensor inputGradient inputLayout inputDevice inputDataType inputShape, SGetLayout inputLayout, SGetDevice inputDevice, SGetDataType inputDataType, SGetShape inputShape, inputBatchDim ~ (inputShape ! 0), inputSeqDim ~ (inputShape ! 1), Scalar fillValue, rightShiftedInput ~ Tensor (inputGradient <|> 'Gradient 'WithoutGradient) inputLayout inputDevice inputDataType (ReplaceDimF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) (inputShape <+> 'Shape '[inputBatchDim, inputSeqDim]) (AddDimF inputSeqDim ('Dim ('Name "*") ('Size 1))))) => HasForward (ShiftRight fillValue) input generator rightShiftedInput generator Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

forward :: MonadThrow m => ShiftRight fillValue -> input -> Generator generator -> m (rightShiftedInput, Generator generator) Source #

HasInitialize (ShiftRight fillValue) generatorDevice (ShiftRight fillValue) generatorDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

Methods

initialize :: MonadThrow m => ModelSpec (ShiftRight fillValue) -> Generator generatorDevice -> m (ShiftRight fillValue, Generator generatorDevice) Source #

type Rep (ShiftRight fillValue) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type Rep (ShiftRight fillValue) = D1 ('MetaData "ShiftRight" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "ShiftRight" 'PrefixI 'False) (S1 ('MetaSel ('Nothing :: Maybe Symbol) 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 fillValue)))
type ModelSpec (ShiftRight fillValue) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.Type

type ModelSpec (ShiftRight fillValue) = ShiftRight fillValue