Safe Haskell | Safe-Inferred |
---|---|
Language | Haskell2010 |
Synopsis
- data TransformerStyle
- type family T5Sym0 :: TransformerStyle where ...
- type family ByT5Sym0 :: TransformerStyle where ...
- type family BARTSym0 :: TransformerStyle where ...
- type family MBARTSym0 :: TransformerStyle where ...
- type family PegasusSym0 :: TransformerStyle where ...
- type family BERTSym0 :: TransformerStyle where ...
- type family RoBERTaSym0 :: TransformerStyle where ...
- type family GPT2Sym0 :: TransformerStyle where ...
- data STransformerStyle :: TransformerStyle -> Type where
- ST5 :: STransformerStyle ('T5 :: TransformerStyle)
- SByT5 :: STransformerStyle ('ByT5 :: TransformerStyle)
- SBART :: STransformerStyle ('BART :: TransformerStyle)
- SMBART :: STransformerStyle ('MBART :: TransformerStyle)
- SPegasus :: STransformerStyle ('Pegasus :: TransformerStyle)
- SBERT :: STransformerStyle ('BERT :: TransformerStyle)
- SRoBERTa :: STransformerStyle ('RoBERTa :: TransformerStyle)
- SGPT2 :: STransformerStyle ('GPT2 :: TransformerStyle)
- data TransformerHead
- type family WithoutHeadSym0 :: TransformerHead where ...
- type family WithLMHeadSym0 :: TransformerHead where ...
- data STransformerHead :: TransformerHead -> Type where
- padded :: Integral n => n -> a -> [a] -> [a]
- mkTransformerInput :: forall batchDim seqDim device m output. (MonadThrow m, SGetDim batchDim, SGetDim seqDim, Catch ('Shape '['Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") 'UncheckedSize] <+> 'Shape '[batchDim, seqDim]), output ~ Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '[batchDim, seqDim])) => Int -> SDim batchDim -> SDim seqDim -> SDevice device -> [[Int]] -> m output
- type MkPosC device shape seqDim seqName seqSize output = (SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), seqDim ~ 'Dim seqName seqSize, output ~ Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") seqSize]))
- mkPos :: forall m gradient layout device dataType shape seqDim seqName seqSize output. (MonadThrow m, MkPosC device shape seqDim seqName seqSize output) => Tensor gradient layout device dataType shape -> m output
- data MkAbsPos
- = MkAbsPos
- | MkAbsPosWithOffset { absPosOffset :: Int }
- mkRelPos' :: Int -> Int -> Int -> Int -> [[Int]]
- type MkRelPosC device shape seqDim seqName seqSize output = (SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), seqDim ~ 'Dim seqName seqSize, Catch ('['Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") 'UncheckedSize] <+> '['Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize]), output ~ Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize]))
- mkRelPos :: forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output. (MonadThrow m, MkRelPosC device shape seqDim seqName seqSize output) => SDim relPosEncBucketDim -> Int -> Tensor gradient layout device dataType shape -> m output
- mkDecoderRelPos' :: Int -> Int -> Int -> Int -> [[Int]]
- mkDecoderRelPos :: forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output. (MonadThrow m, MkRelPosC device shape seqDim seqName seqSize output) => SDim relPosEncBucketDim -> Int -> Tensor gradient layout device dataType shape -> m output
- data MkRelPos (relPosEncBucketDim :: Dim (Name Symbol) (Size Nat)) where
- MkRelPos :: forall relPosEncBucketDim. {..} -> MkRelPos relPosEncBucketDim
- MkDecoderRelPos :: forall relPosEncBucketDim. {..} -> MkRelPos relPosEncBucketDim
- type MkTransformerPaddingMaskC layout device dataType shape output = (SGetDevice device, Catch (dataType <+> 'DataType 'Int64), Catch (BroadcastShapesF shape ('Shape '[])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device ('DataType 'Bool) (BroadcastShapesF shape ('Shape '[])))
- mkTransformerPaddingMask :: forall m gradient layout device dataType shape output. (MonadThrow m, MkTransformerPaddingMaskC layout device dataType shape output) => Int -> Tensor gradient layout device dataType shape -> m output
- newtype MkTransformerPaddingMask = MkTransformerPaddingMask { padTokenId :: Int }
- type MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output = (SGetLayout layout, SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), Catch (gradient <+> 'Gradient 'WithoutGradient), Catch (dataType <+> 'DataType 'Bool), Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape), Catch (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device transformerDataType (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim])))
- mkTransformerAttentionMask :: forall m transformerDataType gradient layout device dataType shape seqDim output. (MonadThrow m, MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output) => SDataType transformerDataType -> Double -> Tensor gradient layout device dataType shape -> m output
- data MkTransformerAttentionMask (dataType :: DataType DType) where
- MkTransformerAttentionMask :: forall dataType. {..} -> MkTransformerAttentionMask dataType
- type MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output = (SGetLayout layout, SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), Catch seqDim, Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape), Catch (BroadcastShapesF ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]) (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)), Catch (BroadcastShapesF (BroadcastShapesF ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]) (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device transformerDataType (BroadcastShapesF (BroadcastShapesF ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]) (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim])))
- mkTransformerDecoderAttentionMask :: forall m transformerDataType gradient layout device dataType shape seqDim output. (MonadThrow m, MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output) => SDataType transformerDataType -> Double -> Tensor gradient layout device dataType shape -> m output
- data MkTransformerDecoderAttentionMask (dataType :: DataType DType) where
- MkTransformerDecoderAttentionMask :: forall dataType. {..} -> MkTransformerDecoderAttentionMask dataType
- type MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output = (SGetLayout layout, SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), SGetShape decoderInputShape, decoderInputSeqDim ~ (decoderInputShape ! 1), Catch (gradient <+> 'Gradient 'WithoutGradient), Catch (dataType <+> 'DataType 'Bool), Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape), Catch (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device transformerDataType (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim])))
- mkTransformerCrossAttentionMask :: forall m transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output. (MonadThrow m, MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output) => SDataType transformerDataType -> SShape decoderInputShape -> Double -> Tensor gradient layout device dataType shape -> m output
- data MkTransformerCrossAttentionMask (dataType :: DataType DType) where
- MkTransformerCrossAttentionMask :: forall dataType. {..} -> MkTransformerCrossAttentionMask dataType
- data ShiftRight fillValue where
- ShiftRight :: forall fillValue. fillValue -> ShiftRight fillValue
Documentation
data TransformerStyle Source #
A data type representing the style of a transformer. Every supported transformer has a constructor of this type.
Constructors:

- T5
- ByT5
- BART
- MBART
- Pegasus
- BERT
- RoBERTa
- GPT2
Instances
Show TransformerStyle Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type | |
Eq TransformerStyle Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
(==) :: TransformerStyle -> TransformerStyle -> Bool Source #
(/=) :: TransformerStyle -> TransformerStyle -> Bool Source # | |
SingKind TransformerStyle Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
type Demote TransformerStyle = (r :: Type) Source #
fromSing :: forall (a :: TransformerStyle). Sing a -> Demote TransformerStyle Source #
toSing :: Demote TransformerStyle -> SomeSing TransformerStyle Source # | |
SingI 'BART Source # | |
SingI 'BERT Source # | |
SingI 'ByT5 Source # | |
SingI 'GPT2 Source # | |
SingI 'MBART Source # | |
SingI 'Pegasus Source # | |
SingI 'RoBERTa Source # | |
SingI 'T5 Source # | |
type Demote TransformerStyle Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type | |
type Sing Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type |
type family PegasusSym0 :: TransformerStyle where ... Source #
type family RoBERTaSym0 :: TransformerStyle where ... Source #
data STransformerStyle :: TransformerStyle -> Type where Source #
ST5 :: STransformerStyle ('T5 :: TransformerStyle) | |
SByT5 :: STransformerStyle ('ByT5 :: TransformerStyle) | |
SBART :: STransformerStyle ('BART :: TransformerStyle) | |
SMBART :: STransformerStyle ('MBART :: TransformerStyle) | |
SPegasus :: STransformerStyle ('Pegasus :: TransformerStyle) | |
SBERT :: STransformerStyle ('BERT :: TransformerStyle) | |
SRoBERTa :: STransformerStyle ('RoBERTa :: TransformerStyle) | |
SGPT2 :: STransformerStyle ('GPT2 :: TransformerStyle) |
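The constructors above are ordinary GADT values, so code can branch on the transformer style at runtime while keeping the style visible in the types. Below is a minimal, hypothetical helper (not part of the library; the name `styleName` is made up) that only assumes the constructors listed above:

```haskell
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..))

-- Hypothetical helper: pattern matching on the singleton recovers the
-- promoted TransformerStyle at the value level.
styleName :: STransformerStyle style -> String
styleName ST5 = "t5"
styleName SByT5 = "byt5"
styleName SBART = "bart"
styleName SMBART = "mbart"
styleName SPegasus = "pegasus"
styleName SBERT = "bert"
styleName SRoBERTa = "roberta"
styleName SGPT2 = "gpt2"
```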
data TransformerHead Source #
A data type representing the type of head used in a transformer.

Constructors:

- WithoutHead
- WithLMHead
Instances
SingKind TransformerHead Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
type Demote TransformerHead = (r :: Type) Source #
fromSing :: forall (a :: TransformerHead). Sing a -> Demote TransformerHead Source #
toSing :: Demote TransformerHead -> SomeSing TransformerHead Source # | |
SingI 'WithLMHead Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type sing :: Sing 'WithLMHead Source # | |
SingI 'WithoutHead Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type sing :: Sing 'WithoutHead Source # | |
type Demote TransformerHead Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type | |
type Sing Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type |
type family WithoutHeadSym0 :: TransformerHead where ... Source #
type family WithLMHeadSym0 :: TransformerHead where ... Source #
data STransformerHead :: TransformerHead -> Type where Source #
mkTransformerInput Source #

:: forall batchDim seqDim device m output. (MonadThrow m, SGetDim batchDim, SGetDim seqDim, Catch ('Shape '['Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") 'UncheckedSize] <+> 'Shape '[batchDim, seqDim]), output ~ Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '[batchDim, seqDim])) | |
=> Int | padding token id |
-> SDim batchDim | batch dimension singleton |
-> SDim seqDim | sequence dimension singleton |
-> SDevice device | device for the tensor |
-> [[Int]] | batch of input ids |
-> m output | input tensor |
Converts a doubly-nested list of input ids into a batched input tensor. The outer list ranges over the batch; each inner list is one sequence of input ids. The batch size is inferred from the length of the outer list, and the sequence length from the lengths of the inner lists. Shorter sequences are padded with the padding token id up to the maximum sequence length, and the output tensor is truncated to that length.
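As a rough illustration of the padding behaviour described above (a sketch, not the library implementation; `padTo` is a hypothetical helper, though the library's own `padded` function plays a similar role), a ragged batch of input ids can be brought to a fixed sequence length like this:

```haskell
-- Hypothetical helper: pad a sequence with the padding token id and truncate
-- it to the given sequence length.
padTo :: Int -> Int -> [Int] -> [Int]
padTo padTokenId seqLen ids = take seqLen (ids ++ repeat padTokenId)

-- >>> map (padTo 0 5) [[101, 2009, 102], [101, 102]]
-- [[101,2009,102,0,0],[101,102,0,0,0]]
```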
type MkPosC device shape seqDim seqName seqSize output = (SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), seqDim ~ 'Dim seqName seqSize, output ~ Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") seqSize])) Source #
mkPos Source #

:: forall m gradient layout device dataType shape seqDim seqName seqSize output. (MonadThrow m, MkPosC device shape seqDim seqName seqSize output) | |
=> Tensor gradient layout device dataType shape | input tensor |
-> m output | positions of the input tokens |
Computes absolute positions of the input tokens.
Given an input tensor of shape `[batchDim, Dim seqName seqSize]`, returns a tensor of shape `[Dim "*" seqSize]`.
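Conceptually (a sketch on plain lists, not the library code), the absolute positions are just the token indices along the sequence dimension:

```haskell
-- Sketch only: absolute positions for a sequence of the given length.
absolutePositions :: Int -> [Int]
absolutePositions seqLen = [0 .. seqLen - 1]

-- >>> absolutePositions 5
-- [0,1,2,3,4]
```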
data MkAbsPos Source #

Constructors:

- MkAbsPos
- MkAbsPosWithOffset { absPosOffset :: Int }

Instances
mkRelPos' :: Int -> Int -> Int -> Int -> [[Int]] Source #
Computes relative positions of the input tokens to the encoder.
>>> mkRelPos' 32 128 21 17
[[0,17,18,19,20,21,22,23,24,24,24,24,25,25,25,25,26],[1,0,17,18,19,20,21,22,23,24,24,24,24,25,25,25,25],[2,1,0,17,18,19,20,21,22,23,24,24,24,24,25,25,25],[3,2,1,0,17,18,19,20,21,22,23,24,24,24,24,25,25],[4,3,2,1,0,17,18,19,20,21,22,23,24,24,24,24,25],[5,4,3,2,1,0,17,18,19,20,21,22,23,24,24,24,24],[6,5,4,3,2,1,0,17,18,19,20,21,22,23,24,24,24],[7,6,5,4,3,2,1,0,17,18,19,20,21,22,23,24,24],[8,7,6,5,4,3,2,1,0,17,18,19,20,21,22,23,24],[8,8,7,6,5,4,3,2,1,0,17,18,19,20,21,22,23],[8,8,8,7,6,5,4,3,2,1,0,17,18,19,20,21,22],[8,8,8,8,7,6,5,4,3,2,1,0,17,18,19,20,21],[9,8,8,8,8,7,6,5,4,3,2,1,0,17,18,19,20],[9,9,8,8,8,8,7,6,5,4,3,2,1,0,17,18,19],[9,9,9,8,8,8,8,7,6,5,4,3,2,1,0,17,18],[9,9,9,9,8,8,8,8,7,6,5,4,3,2,1,0,17],[10,9,9,9,9,8,8,8,8,7,6,5,4,3,2,1,0],[10,10,9,9,9,9,8,8,8,8,7,6,5,4,3,2,1],[10,10,10,9,9,9,9,8,8,8,8,7,6,5,4,3,2],[10,10,10,10,9,9,9,9,8,8,8,8,7,6,5,4,3],[10,10,10,10,10,9,9,9,9,8,8,8,8,7,6,5,4]]
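The matrix above follows T5-style relative-position bucketing. The following self-contained sketch (assumed semantics, not the library source; `relPosBucket` is a hypothetical name) reproduces the entries shown, with rows indexing query positions and columns indexing key positions:

```haskell
-- Sketch of bidirectional T5-style relative-position bucketing (assumed
-- semantics); numBuckets and maxDistance correspond to the first two
-- arguments of mkRelPos'.
relPosBucket :: Int -> Int -> Int -> Int
relPosBucket numBuckets maxDistance relPos =
  let halfBuckets = numBuckets `div` 2
      -- keys to the right of the query land in the upper half of the buckets
      signBucket = if relPos > 0 then halfBuckets else 0
      absPos = abs relPos
      maxExact = halfBuckets `div` 2
      -- beyond maxExact, buckets grow logarithmically up to maxDistance
      largeBucket =
        maxExact
          + floor
              ( logBase (fromIntegral maxDistance / fromIntegral maxExact)
                        (fromIntegral absPos / fromIntegral maxExact)
                  * fromIntegral (halfBuckets - maxExact) :: Double )
   in signBucket
        + (if absPos < maxExact then absPos else min largeBucket (halfBuckets - 1))

-- >>> [[relPosBucket 32 128 (key - query) | key <- [0 .. 16]] | query <- [0 .. 20]]
-- matches the matrix returned by mkRelPos' 32 128 21 17 above
-- (up to floating-point rounding at bucket boundaries).
```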
type MkRelPosC device shape seqDim seqName seqSize output = (SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), seqDim ~ 'Dim seqName seqSize, Catch ('['Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") 'UncheckedSize] <+> '['Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize]), output ~ Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize])) Source #
mkRelPos Source #

:: forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output. (MonadThrow m, MkRelPosC device shape seqDim seqName seqSize output) | |
=> SDim relPosEncBucketDim | bucket dimension |
-> Int | maximum distance |
-> Tensor gradient layout device dataType shape | input tensor |
-> m output | relative positions of the input tokens |
Computes relative positions of the input tokens to the encoder.
Given an input tensor of shape `[batchDim, Dim seqName seqSize]`, returns a tensor of shape `[1, Dim "*" seqSize, Dim "*" seqSize]`.
mkDecoderRelPos' :: Int -> Int -> Int -> Int -> [[Int]] Source #
Computes relative positions of the input tokens to the decoder.
>>> mkDecoderRelPos' 32 128 21 17
[[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[3,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[4,3,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0],[5,4,3,2,1,0,0,0,0,0,0,0,0,0,0,0,0],[6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0,0],[7,6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0],[8,7,6,5,4,3,2,1,0,0,0,0,0,0,0,0,0],[9,8,7,6,5,4,3,2,1,0,0,0,0,0,0,0,0],[10,9,8,7,6,5,4,3,2,1,0,0,0,0,0,0,0],[11,10,9,8,7,6,5,4,3,2,1,0,0,0,0,0,0],[12,11,10,9,8,7,6,5,4,3,2,1,0,0,0,0,0],[13,12,11,10,9,8,7,6,5,4,3,2,1,0,0,0,0],[14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,0,0],[15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,0],[16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0],[16,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1],[16,16,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2],[17,16,16,16,15,14,13,12,11,10,9,8,7,6,5,4,3],[17,17,16,16,16,15,14,13,12,11,10,9,8,7,6,5,4]]
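For the decoder the bucketing is unidirectional: the current and all future key positions fall into bucket 0, and only the distance to earlier positions is bucketed. A hedged, self-contained sketch of the assumed semantics (not the library source; `decoderRelPosBucket` is a hypothetical name):

```haskell
-- Sketch of unidirectional (decoder) relative-position bucketing (assumed
-- semantics); current and future positions map to bucket 0.
decoderRelPosBucket :: Int -> Int -> Int -> Int
decoderRelPosBucket numBuckets maxDistance relPos =
  let dist = max 0 (negate relPos)  -- only look backwards: query - key
      maxExact = numBuckets `div` 2
      largeBucket =
        maxExact
          + floor
              ( logBase (fromIntegral maxDistance / fromIntegral maxExact)
                        (fromIntegral dist / fromIntegral maxExact)
                  * fromIntegral (numBuckets - maxExact) :: Double )
   in if dist < maxExact then dist else min largeBucket (numBuckets - 1)

-- >>> [[decoderRelPosBucket 32 128 (key - query) | key <- [0 .. 16]] | query <- [0 .. 20]]
-- matches the matrix returned by mkDecoderRelPos' 32 128 21 17 above
-- (up to floating-point rounding at bucket boundaries).
```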
mkDecoderRelPos Source #

:: forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output. (MonadThrow m, MkRelPosC device shape seqDim seqName seqSize output) | |
=> SDim relPosEncBucketDim | bucket dimension |
-> Int | maximum distance |
-> Tensor gradient layout device dataType shape | decoder input tensor |
-> m output | relative positions of the input tokens |
Computes relative positions of the input tokens to the decoder.
Given an input tensor of shape `[batchDim, Dim seqName seqSize]`, returns a tensor of shape `[1, Dim "*" seqSize, Dim "*" seqSize]`.
data MkRelPos (relPosEncBucketDim :: Dim (Name Symbol) (Size Nat)) where Source #
Constructors:

- MkRelPos { relPosEncBucketDim :: SDim relPosEncBucketDim, relPosMaxDistance :: Int }
- MkDecoderRelPos { decoderRelPosEncBucketDim :: SDim relPosEncBucketDim, decoderRelPosMaxDistance :: Int }
Instances
Generic (MkRelPos relPosEncBucketDim) Source # | |
Show (MkRelPos relPosEncBucketDim) Source # | |
HasStateDict (MkRelPos relPosEncBucketDim) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (MkRelPos relPosEncBucketDim) -> StateDictKey -> m (MkRelPos relPosEncBucketDim) Source #
toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> MkRelPos relPosEncBucketDim -> m () Source # | |
HasInitialize (MkRelPos relPosEncBucketDim) generatorDevice (MkRelPos relPosEncBucketDim) generatorDevice Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type initialize :: MonadThrow m => ModelSpec (MkRelPos relPosEncBucketDim) -> Generator generatorDevice -> m (MkRelPos relPosEncBucketDim, Generator generatorDevice) Source # | |
MkRelPosC device shape seqDim seqName seqSize output => HasForward (MkRelPos relPosEncBucketDim) (Tensor gradient layout device dataType shape) generatorDevice (Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize])) generatorDevice Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type forward :: MonadThrow m => MkRelPos relPosEncBucketDim -> Tensor gradient layout device dataType shape -> Generator generatorDevice -> m (Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '['Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize]), Generator generatorDevice) Source # | |
type Rep (MkRelPos relPosEncBucketDim) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type type Rep (MkRelPos relPosEncBucketDim) = D1 ('MetaData "MkRelPos" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "MkRelPos" 'PrefixI 'True) (S1 ('MetaSel ('Just "relPosEncBucketDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim relPosEncBucketDim)) :*: S1 ('MetaSel ('Just "relPosMaxDistance") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Int)) :+: C1 ('MetaCons "MkDecoderRelPos" 'PrefixI 'True) (S1 ('MetaSel ('Just "decoderRelPosEncBucketDim") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDim relPosEncBucketDim)) :*: S1 ('MetaSel ('Just "decoderRelPosMaxDistance") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Int))) | |
type ModelSpec (MkRelPos relPosEncBucketDim) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type |
type MkTransformerPaddingMaskC layout device dataType shape output = (SGetDevice device, Catch (dataType <+> 'DataType 'Int64), Catch (BroadcastShapesF shape ('Shape '[])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device ('DataType 'Bool) (BroadcastShapesF shape ('Shape '[]))) Source #
mkTransformerPaddingMask Source #
:: forall m gradient layout device dataType shape output. (MonadThrow m, MkTransformerPaddingMaskC layout device dataType shape output) | |
=> Int | padding token id |
-> Tensor gradient layout device dataType shape | input tensor |
-> m output | padding mask |
Computes the padding mask for a transformer.
Given an input tensor of shape `[batchDim, Dim seqName seqSize]`, returns a tensor of shape `[batchDim, Dim "*" seqSize]`.
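In list form, the mask simply marks which input positions hold the padding token (a conceptual sketch, not the tensor-level implementation):

```haskell
-- Sketch only: True marks padding positions.
paddingMask :: Int -> [[Int]] -> [[Bool]]
paddingMask padTokenId = map (map (== padTokenId))

-- >>> paddingMask 0 [[101,2009,102,0,0],[101,102,0,0,0]]
-- [[False,False,False,True,True],[False,False,True,True,True]]
```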
newtype MkTransformerPaddingMask Source #

Constructors:

- MkTransformerPaddingMask { padTokenId :: Int }

Instances
type MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output = (SGetLayout layout, SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), Catch (gradient <+> 'Gradient 'WithoutGradient), Catch (dataType <+> 'DataType 'Bool), Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape), Catch (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device transformerDataType (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]))) Source #
mkTransformerAttentionMask Source #
:: forall m transformerDataType gradient layout device dataType shape seqDim output. (MonadThrow m, MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output) | |
=> SDataType transformerDataType | data type singleton of the transformer |
-> Double | attention mask bias (typically a large negative number) |
-> Tensor gradient layout device dataType shape | encoder padding mask |
-> m output | attention mask |
Creates a bidirectional attention mask for a transformer.
Given a padding mask of shape `[batchDim, seqDim]`, returns a tensor of shape `[batchDim, seqDim, seqDim]`.
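Conceptually, the padding mask over the keys is broadcast across the query dimension, and masked positions receive the bias while all others receive 0. A sketch of the assumed semantics on plain lists (not the library implementation):

```haskell
-- Sketch only: one [query][key] mask per batch element; padded keys get the
-- bias (a large negative number), everything else 0.
attentionMask :: Double -> [[Bool]] -> [[[Double]]]
attentionMask bias padMask =
  [ [ [ if isPad then bias else 0 | isPad <- keys ]
    | _query <- keys ]
  | keys <- padMask ]
```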
data MkTransformerAttentionMask (dataType :: DataType DType) where Source #
Constructors:

- MkTransformerAttentionMask
Instances
type MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output = (SGetLayout layout, SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), Catch seqDim, Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape), Catch (BroadcastShapesF ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]) (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)), Catch (BroadcastShapesF (BroadcastShapesF ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]) (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device transformerDataType (BroadcastShapesF (BroadcastShapesF ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]) (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)) ('Shape '['Dim ('Name "*") ('Size 1), seqDim, seqDim]))) Source #
mkTransformerDecoderAttentionMask Source #
:: forall m transformerDataType gradient layout device dataType shape seqDim output. (MonadThrow m, MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output) | |
=> SDataType transformerDataType | data type singleton of the transformer |
-> Double | attention mask bias (typically a large negative number) |
-> Tensor gradient layout device dataType shape | decoder padding mask |
-> m output | decoder attention mask |
Creates a causal attention mask for a transformer decoder.
Given a padding mask of shape `[batchDim, seqDim]`, returns a tensor of shape `[batchDim, seqDim, seqDim]`.
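Here causal masking is combined with the padding mask: a query position may only attend to key positions at or before it that are not padding. A sketch of the assumed semantics on plain lists (not the library implementation):

```haskell
-- Sketch only: mask out future key positions and padded key positions.
decoderAttentionMask :: Double -> [[Bool]] -> [[[Double]]]
decoderAttentionMask bias padMask =
  [ [ [ if key > query || isPad then bias else 0
      | (key, isPad) <- zip [0 ..] keys ]
    | query <- [0 .. length keys - 1] ]
  | keys <- padMask ]
```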
data MkTransformerDecoderAttentionMask (dataType :: DataType DType) where Source #
Constructors:

- MkTransformerDecoderAttentionMask
Instances
type MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output = (SGetLayout layout, SGetDevice device, SGetShape shape, seqDim ~ (shape ! 1), SGetShape decoderInputShape, decoderInputSeqDim ~ (decoderInputShape ! 1), Catch (gradient <+> 'Gradient 'WithoutGradient), Catch (dataType <+> 'DataType 'Bool), Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape), Catch (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim])), output ~ Tensor ('Gradient 'WithoutGradient) (layout <+> 'Layout 'Dense) device transformerDataType (BroadcastShapesF (UnsqueezeF ('SelectDim ('ByIndex 1)) shape) ('Shape '['Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim]))) Source #
mkTransformerCrossAttentionMask Source #
:: forall m transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output. (MonadThrow m, MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output) | |
=> SDataType transformerDataType | data type singleton of the transformer |
-> SShape decoderInputShape | decoder input shape |
-> Double | attention mask bias (typically a large negative number) |
-> Tensor gradient layout device dataType shape | encoder padding mask |
-> m output | cross-attention mask |
Creates a cross-attention mask for an encoder-decoder transformer.
Given an encoder padding mask of shape `[batchDim, seqDim]` and the shape `[batchDim, decoderSeqDim]` of the decoder's input, returns a tensor of shape `[batchDim, decoderSeqDim, seqDim]`.
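The encoder padding mask is broadcast over the decoder's sequence dimension, so every decoder position sees the same set of masked encoder positions. A sketch of the assumed semantics on plain lists (not the library implementation):

```haskell
-- Sketch only: one [decoderQuery][encoderKey] mask per batch element.
crossAttentionMask :: Double -> Int -> [[Bool]] -> [[[Double]]]
crossAttentionMask bias decoderSeqLen encoderPadMask =
  [ [ [ if isPad then bias else 0 | isPad <- keys ]
    | _query <- [1 .. decoderSeqLen] ]
  | keys <- encoderPadMask ]
```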
data MkTransformerCrossAttentionMask (dataType :: DataType DType) where Source #
Constructors:

- MkTransformerCrossAttentionMask { crossAttentionMaskDataType :: SDataType dataType, crossAttentionMaskBias :: Double }
Instances
Generic (MkTransformerCrossAttentionMask dataType) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
from :: MkTransformerCrossAttentionMask dataType -> Rep (MkTransformerCrossAttentionMask dataType) x Source #
to :: Rep (MkTransformerCrossAttentionMask dataType) x -> MkTransformerCrossAttentionMask dataType Source # | |
Show (MkTransformerCrossAttentionMask dataType) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type | |
HasStateDict (MkTransformerCrossAttentionMask dataType) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (MkTransformerCrossAttentionMask dataType) -> StateDictKey -> m (MkTransformerCrossAttentionMask dataType) Source #
toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> MkTransformerCrossAttentionMask dataType -> m () Source # | |
HasInitialize (MkTransformerCrossAttentionMask dataType) generatorDevice (MkTransformerCrossAttentionMask dataType) generatorDevice Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type initialize :: MonadThrow m => ModelSpec (MkTransformerCrossAttentionMask dataType) -> Generator generatorDevice -> m (MkTransformerCrossAttentionMask dataType, Generator generatorDevice) Source # | |
MkTransformerCrossAttentionMaskC dataType decoderInputShape decoderInputSeqDim inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaksDataType inputPaddingMaskShape seqDim output => HasForward (MkTransformerCrossAttentionMask dataType) (Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape, Tensor inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaksDataType inputPaddingMaskShape) generatorDevice output generatorDevice Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type forward :: MonadThrow m => MkTransformerCrossAttentionMask dataType -> (Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape, Tensor inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaksDataType inputPaddingMaskShape) -> Generator generatorDevice -> m (output, Generator generatorDevice) Source # | |
type Rep (MkTransformerCrossAttentionMask dataType) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type type Rep (MkTransformerCrossAttentionMask dataType) = D1 ('MetaData "MkTransformerCrossAttentionMask" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "MkTransformerCrossAttentionMask" 'PrefixI 'True) (S1 ('MetaSel ('Just "crossAttentionMaskDataType") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 (SDataType dataType)) :*: S1 ('MetaSel ('Just "crossAttentionMaskBias") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 Double))) | |
type ModelSpec (MkTransformerCrossAttentionMask dataType) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type type ModelSpec (MkTransformerCrossAttentionMask dataType) = MkTransformerCrossAttentionMask dataType |
data ShiftRight fillValue where Source #
Constructors:

- ShiftRight fillValue
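ShiftRight is used to build right-shifted decoder inputs: judging from the HasForward instance below, its forward pass prepends the fill value along the sequence dimension, growing it by one. A sketch on plain lists (assumed semantics, not the library implementation):

```haskell
-- Sketch only: shift each sequence right by prepending the fill value.
shiftRight :: a -> [[a]] -> [[a]]
shiftRight fillValue = map (fillValue :)

-- >>> shiftRight 0 [[7,8,9]]
-- [[0,7,8,9]]
```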
Instances
Generic (ShiftRight fillValue) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
from :: ShiftRight fillValue -> Rep (ShiftRight fillValue) x Source #
to :: Rep (ShiftRight fillValue) x -> ShiftRight fillValue Source # | |
Show fillValue => Show (ShiftRight fillValue) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type | |
Eq fillValue => Eq (ShiftRight fillValue) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
(==) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #
(/=) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source # | |
Ord fillValue => Ord (ShiftRight fillValue) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
compare :: ShiftRight fillValue -> ShiftRight fillValue -> Ordering Source #
(<) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #
(<=) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #
(>) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #
(>=) :: ShiftRight fillValue -> ShiftRight fillValue -> Bool Source #
max :: ShiftRight fillValue -> ShiftRight fillValue -> ShiftRight fillValue Source #
min :: ShiftRight fillValue -> ShiftRight fillValue -> ShiftRight fillValue Source # | |
HasStateDict (ShiftRight fillValue) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type
fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (ShiftRight fillValue) -> StateDictKey -> m (ShiftRight fillValue) Source #
toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> ShiftRight fillValue -> m () Source # | |
(input ~ Tensor inputGradient inputLayout inputDevice inputDataType inputShape, SGetLayout inputLayout, SGetDevice inputDevice, SGetDataType inputDataType, SGetShape inputShape, inputBatchDim ~ (inputShape ! 0), inputSeqDim ~ (inputShape ! 1), Scalar fillValue, rightShiftedInput ~ Tensor (inputGradient <|> 'Gradient 'WithoutGradient) inputLayout inputDevice inputDataType (ReplaceDimF ('SelectDim ('ByIndex 1 :: By Symbol Natural)) (inputShape <+> 'Shape '[inputBatchDim, inputSeqDim]) (AddDimF inputSeqDim ('Dim ('Name "*") ('Size 1))))) => HasForward (ShiftRight fillValue) input generator rightShiftedInput generator Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type forward :: MonadThrow m => ShiftRight fillValue -> input -> Generator generator -> m (rightShiftedInput, Generator generator) Source # | |
HasInitialize (ShiftRight fillValue) generatorDevice (ShiftRight fillValue) generatorDevice Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type initialize :: MonadThrow m => ModelSpec (ShiftRight fillValue) -> Generator generatorDevice -> m (ShiftRight fillValue, Generator generatorDevice) Source # | |
type Rep (ShiftRight fillValue) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type type Rep (ShiftRight fillValue) = D1 ('MetaData "ShiftRight" "Torch.GraduallyTyped.NN.Transformer.Type" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "ShiftRight" 'PrefixI 'False) (S1 ('MetaSel ('Nothing :: Maybe Symbol) 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 fillValue))) | |
type ModelSpec (ShiftRight fillValue) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.Type |