Safe Haskell | Safe-Inferred |
---|---|
Language | Haskell2010 |
Synopsis
- data GGate (layer0 :: Type) (activation :: Type) (layer1 :: Type) where
- data GTransformerFeedForwardNetwork (inputLayerNorm :: Type) (inputTransformation :: Type) (activation :: Type) (activationDropout :: Type) (outputProjection :: Type) (outputDropout :: Type) (outputLayerNorm :: Type) where
- GTransformerFeedForwardNetwork :: forall inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm. {..} -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm
- type family GTransformerFeedForwardNetworkF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) :: Type where ...
- type family FFNInputLayerNormF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ...
- type family FFNInputTransformationF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) :: Type where ...
- type family FFNActivationF (style :: TransformerStyle) :: Type where ...
- type family FFNActivationDropoutF (style :: TransformerStyle) (hasDropout :: HasDropout) :: Type where ...
- type family FFNOutputProjectionF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) :: Type where ...
- type family FFNOutputDropoutF (style :: TransformerStyle) (hasDropout :: HasDropout) :: Type where ...
- type family FFNOutputLayerNormF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ...
- transformerFeedForwardNetworkSpec :: forall style gradient device dataType queryEmbedDim ffnDim hasDropout. STransformerStyle style -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim queryEmbedDim -> SDim ffnDim -> SHasDropout hasDropout -> Double -> Double -> ModelSpec (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout)
Documentation
data GGate (layer0 :: Type) (activation :: Type) (layer1 :: Type) where Source #
Generic two-layer gate with activation function.
layer0 is the first layer.
activation is the activation function.
layer1 is the second layer.
GGate | |
|
Instances
Generic (GGate layer0 activation layer1) Source # | |
(Show layer0, Show activation, Show layer1) => Show (GGate layer0 activation layer1) Source # | |
(Eq layer0, Eq activation, Eq layer1) => Eq (GGate layer0 activation layer1) Source # | |
(Ord layer0, Ord activation, Ord layer1) => Ord (GGate layer0 activation layer1) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork compare :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Ordering Source # (<) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source # (<=) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source # (>) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source # (>=) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source # max :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> GGate layer0 activation layer1 Source # min :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> GGate layer0 activation layer1 Source # | |
(HasStateDict layer0, HasStateDict activation, HasStateDict layer1) => HasStateDict (GGate layer0 activation layer1) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (GGate layer0 activation layer1) -> StateDictKey -> m (GGate layer0 activation layer1) Source # toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> GGate layer0 activation layer1 -> m () Source # | |
(HasInitialize layer0 generatorDevice layer0' generatorDevice0, HasInitialize activation generatorDevice0 activation' generatorDevice1, HasInitialize layer1 generatorDevice1 layer1' generatorOutputDevice) => HasInitialize (GGate layer0 activation layer1) generatorDevice (GGate layer0' activation' layer1') generatorOutputDevice Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork initialize :: MonadThrow m => ModelSpec (GGate layer0 activation layer1) -> Generator generatorDevice -> m (GGate layer0' activation' layer1', Generator generatorOutputDevice) Source # | |
(HasForward layer0 (Tensor gradient layout device dataType shape) generatorDevice (Tensor gradient' layout' device' dataType' shape') generatorDevice', HasForward activation (Tensor gradient' layout' device' dataType' shape') generatorDevice' (Tensor gradient' layout' device' dataType' shape') generatorDevice', HasForward layer1 (Tensor gradient layout device dataType shape) generatorDevice' (Tensor gradient' layout' device' dataType' shape') generatorDevice'', output ~ Tensor gradient' layout' device' dataType' shape', generatorOutputDevice ~ generatorDevice'') => HasForward (GGate layer0 activation layer1) (Tensor gradient layout device dataType shape) generatorDevice output generatorOutputDevice Source # | |
type Rep (GGate layer0 activation layer1) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork type Rep (GGate layer0 activation layer1) = D1 ('MetaData "GGate" "Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "GGate" 'PrefixI 'True) (S1 ('MetaSel ('Just "gateLayer0") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 layer0) :*: (S1 ('MetaSel ('Just "gateActivation") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 activation) :*: S1 ('MetaSel ('Just "gateLayer1") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 layer1)))) | |
type ModelSpec (GGate layer0 activation layer1) Source # | |
data GTransformerFeedForwardNetwork (inputLayerNorm :: Type) (inputTransformation :: Type) (activation :: Type) (activationDropout :: Type) (outputProjection :: Type) (outputDropout :: Type) (outputLayerNorm :: Type) where Source #
Generic transformer feed-forward network.
inputLayerNorm is the layer normalization for the input.
inputTransformation is the input transformation.
activation is the activation function.
activationDropout is the activation dropout layer.
outputProjection is the output projection.
outputDropout is the dropout layer for the output.
outputLayerNorm is the layer normalization for the output.
GTransformerFeedForwardNetwork | |
|
Instances
Generic (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork type Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) :: Type -> Type Source # from :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) x Source # to :: Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) x -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm Source # | |
(Show inputLayerNorm, Show inputTransformation, Show activation, Show activationDropout, Show outputProjection, Show outputDropout, Show outputLayerNorm) => Show (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork showsPrec :: Int -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> ShowS Source # show :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> String Source # showList :: [GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm] -> ShowS Source # | |
(Eq inputLayerNorm, Eq inputTransformation, Eq activation, Eq activationDropout, Eq outputProjection, Eq outputDropout, Eq outputLayerNorm) => Eq (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork (==) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source # (/=) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source # | |
(Ord inputLayerNorm, Ord inputTransformation, Ord activation, Ord activationDropout, Ord outputProjection, Ord outputDropout, Ord outputLayerNorm) => Ord (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork compare :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Ordering Source # (<) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source # (<=) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source # (>) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source # (>=) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source # max :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation 
activation activationDropout outputProjection outputDropout outputLayerNorm Source # min :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm Source # | |
(HasStateDict inputLayerNorm, HasStateDict inputTransformation, HasStateDict activation, HasStateDict activationDropout, HasStateDict outputProjection, HasStateDict outputDropout, HasStateDict outputLayerNorm) => HasStateDict (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) -> StateDictKey -> m (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> m () Source # | |
(HasInitialize inputLayerNorm generatorDevice inputLayerNorm' generatorDevice0, HasInitialize inputTransformation generatorDevice0 inputTransformation' generatorDevice1, HasInitialize activation generatorDevice1 activation' generatorDevice2, HasInitialize activationDropout generatorDevice2 activationDropout' generatorDevice3, HasInitialize outputProjection generatorDevice3 outputProjection' generatorDevice4, HasInitialize outputDropout generatorDevice4 outputDropout' generatorDevice5, HasInitialize outputLayerNorm generatorDevice5 outputLayerNorm' generatorOutputDevice) => HasInitialize (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) generatorDevice (GTransformerFeedForwardNetwork inputLayerNorm' inputTransformation' activation' activationDropout' outputProjection' outputDropout' outputLayerNorm') generatorOutputDevice Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork initialize :: MonadThrow m => ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) -> Generator generatorDevice -> m (GTransformerFeedForwardNetwork inputLayerNorm' inputTransformation' activation' activationDropout' outputProjection' outputDropout' outputLayerNorm', Generator generatorOutputDevice) Source # | |
(HasForward inputLayerNorm (Tensor queryGradient queryLayout queryDevice queryDataType queryShape) generatorDevice tensor0 generatorDevice0, HasForward inputTransformation tensor0 generatorDevice0 tensor1 generatorDevice1, HasForward activation tensor1 generatorDevice1 tensor2 generatorDevice2, HasForward activationDropout tensor2 generatorDevice2 tensor3 generatorDevice3, HasForward outputProjection tensor3 generatorDevice3 tensor4 generatorDevice4, HasForward outputDropout tensor4 generatorDevice4 (Tensor queryGradient5 queryLayout5 queryDevice5 queryDataType5 queryShape5) generatorDevice5, HasForward outputLayerNorm (Tensor (queryGradient <|> queryGradient5) (queryLayout <+> queryLayout5) (queryDevice <+> queryDevice5) (queryDataType <+> queryDataType5) (BroadcastShapesF queryShape queryShape5)) generatorDevice5 output generatorOutputDevice, Catch (BroadcastShapesF queryShape queryShape5)) => HasForward (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) (Tensor queryGradient queryLayout queryDevice queryDataType queryShape) generatorDevice output generatorOutputDevice Source # |
      ┌───────┐
      │ query ├────────┐
      └───┬───┘        │
          │            │
          ▼            │
 (ffnInputLayerNorm)   │
          ▼            │
ffnInputTransformation │
          ▼            │
    ffnActivation      │
          ▼            │
(ffnActivationDropout) │
          ▼            │
 ffnOutputProjection   │
          ▼            │
   ffnOutputDropout    │
          │            │
          ▼            │
         add◄──────────┘
          │
          ▼
 (ffnOutputLayerNorm)
          │
          ▼
      ┌───────┐
      │ query │
      └───────┘ |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork forward :: MonadThrow m => GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Tensor queryGradient queryLayout queryDevice queryDataType queryShape -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source # | |
type Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork type Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) = D1 ('MetaData "GTransformerFeedForwardNetwork" "Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "GTransformerFeedForwardNetwork" 'PrefixI 'True) ((S1 ('MetaSel ('Just "ffnInputLayerNorm") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 inputLayerNorm) :*: (S1 ('MetaSel ('Just "ffnInputTransformation") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 inputTransformation) :*: S1 ('MetaSel ('Just "ffnActivation") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 activation))) :*: ((S1 ('MetaSel ('Just "ffnActivationDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 activationDropout) :*: S1 ('MetaSel ('Just "ffnOutputProjection") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 outputProjection)) :*: (S1 ('MetaSel ('Just "ffnOutputDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 outputDropout) :*: S1 ('MetaSel ('Just "ffnOutputLayerNorm") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 outputLayerNorm))))) | |
type ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # | |
Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork type ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) = GTransformerFeedForwardNetwork (ModelSpec inputLayerNorm) (ModelSpec inputTransformation) (ModelSpec activation) (ModelSpec activationDropout) (ModelSpec outputProjection) (ModelSpec outputDropout) (ModelSpec outputLayerNorm) |
type family GTransformerFeedForwardNetworkF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) :: Type where ... Source #
GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout = GTransformerFeedForwardNetwork (FFNInputLayerNormF style gradient device dataType queryEmbedDim) (FFNInputTransformationF style gradient device dataType queryEmbedDim ffnDim) (FFNActivationF style) (FFNActivationDropoutF style hasDropout) (FFNOutputProjectionF style gradient device dataType queryEmbedDim ffnDim) (FFNOutputDropoutF style hasDropout) (FFNOutputLayerNormF style gradient device dataType queryEmbedDim) |
type family FFNInputLayerNormF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #
Specifies the layer normalization for the input.
FFNInputLayerNormF 'T5 gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[queryEmbedDim])) | |
FFNInputLayerNormF 'ByT5 gradient device dataType queryEmbedDim = FFNInputLayerNormF 'T5 gradient device dataType queryEmbedDim | |
FFNInputLayerNormF 'BART _ _ _ _ = () | |
FFNInputLayerNormF 'MBART gradient device dataType queryEmbedDim = FFNInputLayerNormF 'BART gradient device dataType queryEmbedDim | |
FFNInputLayerNormF 'Pegasus gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim])) | |
FFNInputLayerNormF 'BERT _ _ _ _ = () | |
FFNInputLayerNormF 'RoBERTa gradient device dataType queryEmbedDim = FFNInputLayerNormF 'BERT gradient device dataType queryEmbedDim |
type family FFNInputTransformationF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #
Specifies the first input projection.
FFNInputTransformationF 'T5 gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim) | |
FFNInputTransformationF 'ByT5 gradient device dataType queryEmbedDim ffnDim = GGate (NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim)) GeluNew (NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim)) | |
FFNInputTransformationF _ gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithBias gradient device dataType queryEmbedDim ffnDim) |
type family FFNActivationF (style :: TransformerStyle) :: Type where ... Source #
Specifies the activation.
FFNActivationF 'T5 = Relu | |
FFNActivationF 'ByT5 = GeluNew | |
FFNActivationF 'BART = Gelu | |
FFNActivationF 'MBART = Gelu | |
FFNActivationF 'Pegasus = Relu | |
FFNActivationF 'BERT = Gelu | |
FFNActivationF 'RoBERTa = Gelu |
type family FFNActivationDropoutF (style :: TransformerStyle) (hasDropout :: HasDropout) :: Type where ... Source #
Specifies the activation dropout.
FFNActivationDropoutF 'T5 'WithDropout = Dropout | |
FFNActivationDropoutF 'ByT5 hasDropout = FFNActivationDropoutF 'T5 hasDropout | |
FFNActivationDropoutF 'BART 'WithDropout = Dropout | |
FFNActivationDropoutF 'MBART hasDropout = FFNActivationDropoutF 'BART hasDropout | |
FFNActivationDropoutF 'Pegasus hasDropout = FFNActivationDropoutF 'BART hasDropout | |
FFNActivationDropoutF 'BERT _ = () | |
FFNActivationDropoutF 'RoBERTa hasDropout = FFNActivationDropoutF 'BERT hasDropout | |
FFNActivationDropoutF _ 'WithoutDropout = () |
type family FFNOutputProjectionF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #
Specifies the output projection.
FFNOutputProjectionF 'T5 gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithoutBias gradient device dataType ffnDim queryEmbedDim) | |
FFNOutputProjectionF 'ByT5 gradient device dataType queryEmbedDim ffnDim = FFNOutputProjectionF 'T5 gradient device dataType queryEmbedDim ffnDim | |
FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithBias gradient device dataType ffnDim queryEmbedDim) | |
FFNOutputProjectionF 'MBART gradient device dataType queryEmbedDim ffnDim = FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim | |
FFNOutputProjectionF 'Pegasus gradient device dataType queryEmbedDim ffnDim = FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim | |
FFNOutputProjectionF 'BERT gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithBias gradient device dataType ffnDim queryEmbedDim) | |
FFNOutputProjectionF 'RoBERTa gradient device dataType queryEmbedDim ffnDim = FFNOutputProjectionF 'BERT gradient device dataType queryEmbedDim ffnDim |
type family FFNOutputDropoutF (style :: TransformerStyle) (hasDropout :: HasDropout) :: Type where ... Source #
Specifies the dropout for the output.
FFNOutputDropoutF _ 'WithDropout = Dropout | |
FFNOutputDropoutF _ 'WithoutDropout = () |
type family FFNOutputLayerNormF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #
Specifies the layer normalization for the output.
FFNOutputLayerNormF 'T5 _ _ _ _ = () | |
FFNOutputLayerNormF 'ByT5 gradient device dataType queryEmbedDim = FFNOutputLayerNormF 'T5 gradient device dataType queryEmbedDim | |
FFNOutputLayerNormF 'BART gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim])) | |
FFNOutputLayerNormF 'MBART gradient device dataType queryEmbedDim = FFNOutputLayerNormF 'BART gradient device dataType queryEmbedDim | |
FFNOutputLayerNormF 'Pegasus _ _ _ _ = () | |
FFNOutputLayerNormF 'BERT gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim])) | |
FFNOutputLayerNormF 'RoBERTa gradient device dataType queryEmbedDim = FFNOutputLayerNormF 'BERT gradient device dataType queryEmbedDim |
transformerFeedForwardNetworkSpec :: forall style gradient device dataType queryEmbedDim ffnDim hasDropout. STransformerStyle style -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim queryEmbedDim -> SDim ffnDim -> SHasDropout hasDropout -> Double -> Double -> ModelSpec (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout) Source #
Specifies the parameters of the transformer feed-forward network.
style: the style of the transformer feed-forward network, e.g. ST5, SByT5, etc.
gradient: whether to compute the gradient of the network's parameters.
device: the computational device on which the parameters are allocated.
dataType: the data type of the parameters.
queryEmbedDim: the dimension of the query embedding.
ffnDim: the dimension of the feed-forward network's hidden state.
hasDropout: whether the network includes dropout layers ('WithDropout) or not ('WithoutDropout).
dropoutP: the dropout rate.
eps: the epsilon value for numerical stability of the layer normalization.