hasktorch-gradually-typed-0.2.0.0: experimental project for hasktorch
Safe Haskell: Safe-Inferred
Language: Haskell2010

Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Synopsis

Documentation

data GGate (layer0 :: Type) (activation :: Type) (layer1 :: Type) where Source #

Generic two-layer gate with an activation function.

  • layer0 is the first layer.
  • activation is the activation function.
  • layer1 is the second layer.

Constructors

GGate 

Fields

  • :: forall layer0 activation layer1. { gateLayer0 :: layer0

    first gate layer

  •    , gateActivation :: activation

    gate activation

  •    , gateLayer1 :: layer1

    second gate layer

  •    } -> GGate layer0 activation layer1
     

Instances

Instances details
Generic (GGate layer0 activation layer1) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Associated Types

type Rep (GGate layer0 activation layer1) :: Type -> Type Source #

Methods

from :: GGate layer0 activation layer1 -> Rep (GGate layer0 activation layer1) x Source #

to :: Rep (GGate layer0 activation layer1) x -> GGate layer0 activation layer1 Source #

(Show layer0, Show activation, Show layer1) => Show (GGate layer0 activation layer1) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

showsPrec :: Int -> GGate layer0 activation layer1 -> ShowS Source #

show :: GGate layer0 activation layer1 -> String Source #

showList :: [GGate layer0 activation layer1] -> ShowS Source #

(Eq layer0, Eq activation, Eq layer1) => Eq (GGate layer0 activation layer1) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

(==) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source #

(/=) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source #

(Ord layer0, Ord activation, Ord layer1) => Ord (GGate layer0 activation layer1) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

compare :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Ordering Source #

(<) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source #

(<=) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source #

(>) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source #

(>=) :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> Bool Source #

max :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> GGate layer0 activation layer1 Source #

min :: GGate layer0 activation layer1 -> GGate layer0 activation layer1 -> GGate layer0 activation layer1 Source #

(HasStateDict layer0, HasStateDict activation, HasStateDict layer1) => HasStateDict (GGate layer0 activation layer1) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (GGate layer0 activation layer1) -> StateDictKey -> m (GGate layer0 activation layer1) Source #

toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> GGate layer0 activation layer1 -> m () Source #

(HasInitialize layer0 generatorDevice layer0' generatorDevice0, HasInitialize activation generatorDevice0 activation' generatorDevice1, HasInitialize layer1 generatorDevice1 layer1' generatorOutputDevice) => HasInitialize (GGate layer0 activation layer1) generatorDevice (GGate layer0' activation' layer1') generatorOutputDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

initialize :: MonadThrow m => ModelSpec (GGate layer0 activation layer1) -> Generator generatorDevice -> m (GGate layer0' activation' layer1', Generator generatorOutputDevice) Source #

(HasForward layer0 (Tensor gradient layout device dataType shape) generatorDevice (Tensor gradient' layout' device' dataType' shape') generatorDevice', HasForward activation (Tensor gradient' layout' device' dataType' shape') generatorDevice' (Tensor gradient' layout' device' dataType' shape') generatorDevice', HasForward layer1 (Tensor gradient layout device dataType shape) generatorDevice' (Tensor gradient' layout' device' dataType' shape') generatorDevice'', output ~ Tensor gradient' layout' device' dataType' shape', generatorOutputDevice ~ generatorDevice'') => HasForward (GGate layer0 activation layer1) (Tensor gradient layout device dataType shape) generatorDevice output generatorOutputDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

forward :: MonadThrow m => GGate layer0 activation layer1 -> Tensor gradient layout device dataType shape -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #

type Rep (GGate layer0 activation layer1) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

type Rep (GGate layer0 activation layer1) = D1 ('MetaData "GGate" "Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "GGate" 'PrefixI 'True) (S1 ('MetaSel ('Just "gateLayer0") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 layer0) :*: (S1 ('MetaSel ('Just "gateActivation") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 activation) :*: S1 ('MetaSel ('Just "gateLayer1") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 layer1))))
type ModelSpec (GGate layer0 activation layer1) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

type ModelSpec (GGate layer0 activation layer1) = GGate (ModelSpec layer0) (ModelSpec activation) (ModelSpec layer1)

data GTransformerFeedForwardNetwork (inputLayerNorm :: Type) (inputTransformation :: Type) (activation :: Type) (activationDropout :: Type) (outputProjection :: Type) (outputDropout :: Type) (outputLayerNorm :: Type) where Source #

Generic transformer feed-forward network.

  • inputLayerNorm is the layer normalization for the input.
  • inputTransformation is the input transformation.
  • activation is the activation function.
  • activationDropout is the activation dropout layer.
  • outputProjection is the output projection.
  • outputDropout is the dropout layer for the output.
  • outputLayerNorm is the layer normalization for the output.

Constructors

GTransformerFeedForwardNetwork 

Fields

Instances

Instances details
Generic (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Associated Types

type Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) :: Type -> Type Source #

Methods

from :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) x Source #

to :: Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) x -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm Source #

(Show inputLayerNorm, Show inputTransformation, Show activation, Show activationDropout, Show outputProjection, Show outputDropout, Show outputLayerNorm) => Show (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

showsPrec :: Int -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> ShowS Source #

show :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> String Source #

showList :: [GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm] -> ShowS Source #

(Eq inputLayerNorm, Eq inputTransformation, Eq activation, Eq activationDropout, Eq outputProjection, Eq outputDropout, Eq outputLayerNorm) => Eq (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

(==) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source #

(/=) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source #

(Ord inputLayerNorm, Ord inputTransformation, Ord activation, Ord activationDropout, Ord outputProjection, Ord outputDropout, Ord outputLayerNorm) => Ord (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

compare :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Ordering Source #

(<) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source #

(<=) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source #

(>) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source #

(>=) :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Bool Source #

max :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm Source #

min :: GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm Source #

(HasStateDict inputLayerNorm, HasStateDict inputTransformation, HasStateDict activation, HasStateDict activationDropout, HasStateDict outputProjection, HasStateDict outputDropout, HasStateDict outputLayerNorm) => HasStateDict (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

fromStateDict :: (MonadIO m, MonadThrow m, MonadState StateDict m) => ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) -> StateDictKey -> m (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source #

toStateDict :: (MonadThrow m, MonadState StateDict m) => StateDictKey -> GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> m () Source #

(HasInitialize inputLayerNorm generatorDevice inputLayerNorm' generatorDevice0, HasInitialize inputTransformation generatorDevice0 inputTransformation' generatorDevice1, HasInitialize activation generatorDevice1 activation' generatorDevice2, HasInitialize activationDropout generatorDevice2 activationDropout' generatorDevice3, HasInitialize outputProjection generatorDevice3 outputProjection' generatorDevice4, HasInitialize outputDropout generatorDevice4 outputDropout' generatorDevice5, HasInitialize outputLayerNorm generatorDevice5 outputLayerNorm' generatorOutputDevice) => HasInitialize (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) generatorDevice (GTransformerFeedForwardNetwork inputLayerNorm' inputTransformation' activation' activationDropout' outputProjection' outputDropout' outputLayerNorm') generatorOutputDevice Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

initialize :: MonadThrow m => ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) -> Generator generatorDevice -> m (GTransformerFeedForwardNetwork inputLayerNorm' inputTransformation' activation' activationDropout' outputProjection' outputDropout' outputLayerNorm', Generator generatorOutputDevice) Source #

(HasForward inputLayerNorm (Tensor queryGradient queryLayout queryDevice queryDataType queryShape) generatorDevice tensor0 generatorDevice0, HasForward inputTransformation tensor0 generatorDevice0 tensor1 generatorDevice1, HasForward activation tensor1 generatorDevice1 tensor2 generatorDevice2, HasForward activationDropout tensor2 generatorDevice2 tensor3 generatorDevice3, HasForward outputProjection tensor3 generatorDevice3 tensor4 generatorDevice4, HasForward outputDropout tensor4 generatorDevice4 (Tensor queryGradient5 queryLayout5 queryDevice5 queryDataType5 queryShape5) generatorDevice5, HasForward outputLayerNorm (Tensor (queryGradient <|> queryGradient5) (queryLayout <+> queryLayout5) (queryDevice <+> queryDevice5) (queryDataType <+> queryDataType5) (BroadcastShapesF queryShape queryShape5)) generatorDevice5 output generatorOutputDevice, Catch (BroadcastShapesF queryShape queryShape5)) => HasForward (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) (Tensor queryGradient queryLayout queryDevice queryDataType queryShape) generatorDevice output generatorOutputDevice Source #

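HasForward instance for GTransformerFeedForwardNetwork. Components shown in parentheses are style-dependent and may be the unit type () for some styles, e.g. ffnInputLayerNorm for BART and ffnOutputLayerNorm for T5.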

      ┌───────┐
      │ query ├────────┐
      └───┬───┘        │
          │            │
          ▼            │
 (ffnInputLayerNorm)   │
          ▼            │
ffnInputTransformation │
          ▼            │
    ffnActivation      │
          ▼            │
(ffnActivationDropout) │
          ▼            │
  ffnOutputProjection  │
          ▼            │
   ffnOutputDropout    │
          │            │
          ▼            │
         add◄──────────┘
          │
          ▼
 (ffnOutputLayerNorm)
          │
          ▼
      ┌───────┐
      │ query │
      └───────┘
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

Methods

forward :: MonadThrow m => GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm -> Tensor queryGradient queryLayout queryDevice queryDataType queryShape -> Generator generatorDevice -> m (output, Generator generatorOutputDevice) Source #

type Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

type Rep (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) = D1 ('MetaData "GTransformerFeedForwardNetwork" "Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork" "hasktorch-gradually-typed-0.2.0.0-1KV1aIPzzbp6JpSr37tC1K" 'False) (C1 ('MetaCons "GTransformerFeedForwardNetwork" 'PrefixI 'True) ((S1 ('MetaSel ('Just "ffnInputLayerNorm") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 inputLayerNorm) :*: (S1 ('MetaSel ('Just "ffnInputTransformation") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 inputTransformation) :*: S1 ('MetaSel ('Just "ffnActivation") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 activation))) :*: ((S1 ('MetaSel ('Just "ffnActivationDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 activationDropout) :*: S1 ('MetaSel ('Just "ffnOutputProjection") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 outputProjection)) :*: (S1 ('MetaSel ('Just "ffnOutputDropout") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 outputDropout) :*: S1 ('MetaSel ('Just "ffnOutputLayerNorm") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedLazy) (Rec0 outputLayerNorm)))))
type ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) Source # 
Instance details

Defined in Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork

type ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) = GTransformerFeedForwardNetwork (ModelSpec inputLayerNorm) (ModelSpec inputTransformation) (ModelSpec activation) (ModelSpec activationDropout) (ModelSpec outputProjection) (ModelSpec outputDropout) (ModelSpec outputLayerNorm)

type family GTransformerFeedForwardNetworkF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) (hasDropout :: HasDropout) :: Type where ... Source #

Equations

GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout = GTransformerFeedForwardNetwork (FFNInputLayerNormF style gradient device dataType queryEmbedDim) (FFNInputTransformationF style gradient device dataType queryEmbedDim ffnDim) (FFNActivationF style) (FFNActivationDropoutF style hasDropout) (FFNOutputProjectionF style gradient device dataType queryEmbedDim ffnDim) (FFNOutputDropoutF style hasDropout) (FFNOutputLayerNormF style gradient device dataType queryEmbedDim) 

type family FFNInputLayerNormF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

Specifies the layer normalization for the input.

Equations

FFNInputLayerNormF 'T5 gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[queryEmbedDim])) 
FFNInputLayerNormF 'ByT5 gradient device dataType queryEmbedDim = FFNInputLayerNormF 'T5 gradient device dataType queryEmbedDim 
FFNInputLayerNormF 'BART _ _ _ _ = () 
FFNInputLayerNormF 'MBART gradient device dataType queryEmbedDim = FFNInputLayerNormF 'BART gradient device dataType queryEmbedDim 
FFNInputLayerNormF 'Pegasus gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim])) 
FFNInputLayerNormF 'BERT _ _ _ _ = () 
FFNInputLayerNormF 'RoBERTa gradient device dataType queryEmbedDim = FFNInputLayerNormF 'BERT gradient device dataType queryEmbedDim 

type family FFNInputTransformationF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

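Specifies the input transformation, i.e. the first projection of the feed-forward network from queryEmbedDim to ffnDim. For ByT5 this is a gated projection (GGate) with a GeluNew activation.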

Equations

FFNInputTransformationF 'T5 gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim) 
FFNInputTransformationF 'ByT5 gradient device dataType queryEmbedDim ffnDim = GGate (NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim)) GeluNew (NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim)) 
FFNInputTransformationF _ gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithBias gradient device dataType queryEmbedDim ffnDim) 

type family FFNOutputProjectionF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) (ffnDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

Specifies the output projection.

Equations

FFNOutputProjectionF 'T5 gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithoutBias gradient device dataType ffnDim queryEmbedDim) 
FFNOutputProjectionF 'ByT5 gradient device dataType queryEmbedDim ffnDim = FFNOutputProjectionF 'T5 gradient device dataType queryEmbedDim ffnDim 
FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithBias gradient device dataType ffnDim queryEmbedDim) 
FFNOutputProjectionF 'MBART gradient device dataType queryEmbedDim ffnDim = FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim 
FFNOutputProjectionF 'Pegasus gradient device dataType queryEmbedDim ffnDim = FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim 
FFNOutputProjectionF 'BERT gradient device dataType queryEmbedDim ffnDim = NamedModel (GLinearF 'WithBias gradient device dataType ffnDim queryEmbedDim) 
FFNOutputProjectionF 'RoBERTa gradient device dataType queryEmbedDim ffnDim = FFNOutputProjectionF 'BERT gradient device dataType queryEmbedDim ffnDim 

type family FFNOutputDropoutF (style :: TransformerStyle) (hasDropout :: HasDropout) :: Type where ... Source #

Specifies the dropout for the output.

type family FFNOutputLayerNormF (style :: TransformerStyle) (gradient :: Gradient RequiresGradient) (device :: Device (DeviceType Nat)) (dataType :: DataType DType) (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) :: Type where ... Source #

Specifies the layer normalization for the output.

Equations

FFNOutputLayerNormF 'T5 _ _ _ _ = () 
FFNOutputLayerNormF 'ByT5 gradient device dataType queryEmbedDim = FFNOutputLayerNormF 'T5 gradient device dataType queryEmbedDim 
FFNOutputLayerNormF 'BART gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim])) 
FFNOutputLayerNormF 'MBART gradient device dataType queryEmbedDim = FFNOutputLayerNormF 'BART gradient device dataType queryEmbedDim 
FFNOutputLayerNormF 'Pegasus _ _ _ _ = () 
FFNOutputLayerNormF 'BERT gradient device dataType queryEmbedDim = NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim])) 
FFNOutputLayerNormF 'RoBERTa gradient device dataType queryEmbedDim = FFNOutputLayerNormF 'BERT gradient device dataType queryEmbedDim 

transformerFeedForwardNetworkSpec :: forall style gradient device dataType queryEmbedDim ffnDim hasDropout. STransformerStyle style -> SGradient gradient -> SDevice device -> SDataType dataType -> SDim queryEmbedDim -> SDim ffnDim -> SHasDropout hasDropout -> Double -> Double -> ModelSpec (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout) Source #

Specifies the parameters of the transformer feed-forward network.

  • style: the style of the transformer feed-forward network, e.g. ST5 or SByT5.
  • gradient: whether to compute the gradient of the network's parameters.
  • device: the computational device on which the parameters are allocated.
  • dataType: the data type of the parameters.
  • queryEmbedDim: the dimension of the query embedding.
  • ffnDim: the dimension of the feed-forward network's hidden state.
  • dropoutP: the dropout rate.
  • eps: the epsilon value for numerical stability of the layer normalization.