{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork where

import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType, SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Activation (Gelu (..), GeluNew (..), Relu (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Dropout (Dropout (..))
import Torch.GraduallyTyped.NN.Linear (GLinear (..), GLinearF, linearSpec)
import Torch.GraduallyTyped.NN.Normalization (LayerNorm (..), LayerNormSpec (..))
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), HasDropout (..), SHasBias (..), SHasDropout (..))
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, SShape (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

-- | Generic two-layer gate with activation function.
--
-- The forward pass computes @activation (layer0 x) * layer1 x@,
-- i.e. the activated first branch gates the second branch pointwise.
--
-- - @layer0@ is the first layer.
-- - @activation@ is the activation function.
-- - @layer1@ is the second layer.
data
  GGate
    (layer0 :: Type)
    (activation :: Type)
    (layer1 :: Type)
  where
  GGate ::
    forall layer0 activation layer1.
    { -- | first gate layer
      gateLayer0 :: layer0,
      -- | gate activation
      gateActivation :: activation,
      -- | second gate layer
      gateLayer1 :: layer1
    } ->
    GGate layer0 activation layer1
  deriving stock (Eq, Ord, Show, Generic)

-- | The specification of a 'GGate' is a 'GGate' of the component
-- specifications, so a gate is specified layer by layer.
type instance
  ModelSpec (GGate layer0 activation layer1) =
    GGate (ModelSpec layer0) (ModelSpec activation) (ModelSpec layer1)

-- | Initialization of a 'GGate' initializes each component in order
-- (layer0, then activation, then layer1), threading the generator
-- device type through each step. The method implementation comes from
-- the class's generic default (no @where@ clause here).
instance
  ( HasInitialize layer0 generatorDevice layer0' generatorDevice0,
    HasInitialize activation generatorDevice0 activation' generatorDevice1,
    HasInitialize layer1 generatorDevice1 layer1' generatorOutputDevice
  ) =>
  HasInitialize
    (GGate layer0 activation layer1)
    generatorDevice
    (GGate layer0' activation' layer1')
    generatorOutputDevice

-- | A 'GGate' has a state dictionary exactly when all of its components
-- do. The method implementations come from the class's generic default
-- (no @where@ clause here).
instance
  (HasStateDict layer0, HasStateDict activation, HasStateDict layer1) =>
  HasStateDict (GGate layer0 activation layer1)

-- | The forward pass of a 'GGate':
--
-- @output = activation (layer0 input) * layer1 input@
--
-- Both branches consume the same @input@; the generator is threaded
-- first through @layer0@, then the activation, then @layer1@.
instance
  ( HasForward
      layer0
      (Tensor gradient layout device dataType shape)
      generatorDevice
      (Tensor gradient' layout' device' dataType' shape')
      generatorDevice',
    HasForward
      activation
      (Tensor gradient' layout' device' dataType' shape')
      generatorDevice'
      (Tensor gradient' layout' device' dataType' shape')
      generatorDevice',
    HasForward
      layer1
      (Tensor gradient layout device dataType shape)
      generatorDevice'
      (Tensor gradient' layout' device' dataType' shape')
      generatorDevice'',
    output ~ Tensor gradient' layout' device' dataType' shape',
    generatorOutputDevice ~ generatorDevice''
  ) =>
  HasForward
    (GGate layer0 activation layer1)
    (Tensor gradient layout device dataType shape)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward GGate {..} input =
    let -- first branch: layer0 followed by the activation
        activate input' =
          ireturn input'
            >>>= IxStateT . forward gateLayer0
            >>>= IxStateT . forward gateActivation
        -- gate: pointwise product of the activated branch and layer1
        gate input' =
          (*) <<$>> activate input'
            <<*>> (IxStateT . forward gateLayer1 $ input')
     in runIxStateT $ ireturn input >>>= gate

-- | Generic transformer feed-forward network.
--
-- - @inputLayerNorm@ is the layer normalization for the input.
-- - @inputTransformation@ is the input transformation.
-- - @activation@ is the activation function.
-- - @activationDropout@ is the activation dropout layer.
-- - @outputProjection@ is the output projection.
-- - @outputDropout@ is the dropout layer for the output.
-- - @outputLayerNorm@ is the layer normalization for the output.
data
  GTransformerFeedForwardNetwork
    (inputLayerNorm :: Type)
    (inputTransformation :: Type)
    (activation :: Type)
    (activationDropout :: Type)
    (outputProjection :: Type)
    (outputDropout :: Type)
    (outputLayerNorm :: Type)
  where
  GTransformerFeedForwardNetwork ::
    forall inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm.
    { -- | input layer norm
      forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> inputLayerNorm
ffnInputLayerNorm :: inputLayerNorm,
      -- | input transformation
      forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> inputTransformation
ffnInputTransformation :: inputTransformation,
      -- | activation
      forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> activation
ffnActivation :: activation,
      -- | activation dropout
      forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> activationDropout
ffnActivationDropout :: activationDropout,
      -- | output projection
      forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> outputProjection
ffnOutputProjection :: outputProjection,
      -- | output dropout
      forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> outputDropout
ffnOutputDropout :: outputDropout,
      -- | output layer norm
      forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> outputLayerNorm
ffnOutputLayerNorm :: outputLayerNorm
    } ->
    GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm
  deriving stock (GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Eq inputLayerNorm, Eq inputTransformation, Eq activation,
 Eq activationDropout, Eq outputProjection, Eq outputDropout,
 Eq outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
/= :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
$c/= :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Eq inputLayerNorm, Eq inputTransformation, Eq activation,
 Eq activationDropout, Eq outputProjection, Eq outputDropout,
 Eq outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
== :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
$c== :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Eq inputLayerNorm, Eq inputTransformation, Eq activation,
 Eq activationDropout, Eq outputProjection, Eq outputDropout,
 Eq outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
Eq, GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Ordering
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall {inputLayerNorm} {inputTransformation} {activation}
       {activationDropout} {outputProjection} {outputDropout}
       {outputLayerNorm}.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
Eq
  (GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm)
forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Ordering
forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
min :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
$cmin :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
max :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
$cmax :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
>= :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
$c>= :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
> :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
$c> :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
<= :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
$c<= :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
< :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
$c< :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Bool
compare :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Ordering
$ccompare :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Ord inputLayerNorm, Ord inputTransformation, Ord activation,
 Ord activationDropout, Ord outputProjection, Ord outputDropout,
 Ord outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> Ordering
Ord, Int
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Show inputLayerNorm, Show inputTransformation, Show activation,
 Show activationDropout, Show outputProjection, Show outputDropout,
 Show outputLayerNorm) =>
Int
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> ShowS
forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Show inputLayerNorm, Show inputTransformation, Show activation,
 Show activationDropout, Show outputProjection, Show outputDropout,
 Show outputLayerNorm) =>
[GTransformerFeedForwardNetwork
   inputLayerNorm
   inputTransformation
   activation
   activationDropout
   outputProjection
   outputDropout
   outputLayerNorm]
-> ShowS
forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Show inputLayerNorm, Show inputTransformation, Show activation,
 Show activationDropout, Show outputProjection, Show outputDropout,
 Show outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> String
showList :: [GTransformerFeedForwardNetwork
   inputLayerNorm
   inputTransformation
   activation
   activationDropout
   outputProjection
   outputDropout
   outputLayerNorm]
-> ShowS
$cshowList :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Show inputLayerNorm, Show inputTransformation, Show activation,
 Show activationDropout, Show outputProjection, Show outputDropout,
 Show outputLayerNorm) =>
[GTransformerFeedForwardNetwork
   inputLayerNorm
   inputTransformation
   activation
   activationDropout
   outputProjection
   outputDropout
   outputLayerNorm]
-> ShowS
show :: GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> String
$cshow :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Show inputLayerNorm, Show inputTransformation, Show activation,
 Show activationDropout, Show outputProjection, Show outputDropout,
 Show outputLayerNorm) =>
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> String
showsPrec :: Int
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> ShowS
$cshowsPrec :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
(Show inputLayerNorm, Show inputTransformation, Show activation,
 Show activationDropout, Show outputProjection, Show outputDropout,
 Show outputLayerNorm) =>
Int
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm x.
Rep
  (GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm)
  x
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm x.
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> Rep
     (GTransformerFeedForwardNetwork
        inputLayerNorm
        inputTransformation
        activation
        activationDropout
        outputProjection
        outputDropout
        outputLayerNorm)
     x
$cto :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm x.
Rep
  (GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm)
  x
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
$cfrom :: forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm x.
GTransformerFeedForwardNetwork
  inputLayerNorm
  inputTransformation
  activation
  activationDropout
  outputProjection
  outputDropout
  outputLayerNorm
-> Rep
     (GTransformerFeedForwardNetwork
        inputLayerNorm
        inputTransformation
        activation
        activationDropout
        outputProjection
        outputDropout
        outputLayerNorm)
     x
Generic)

-- | The specification of a 'GTransformerFeedForwardNetwork' is the same record
-- shape with every component replaced by that component's own 'ModelSpec'.
type instance
  ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) =
    GTransformerFeedForwardNetwork (ModelSpec inputLayerNorm) (ModelSpec inputTransformation) (ModelSpec activation) (ModelSpec activationDropout) (ModelSpec outputProjection) (ModelSpec outputDropout) (ModelSpec outputLayerNorm)

-- | Specifies the full feed forward network for a given transformer @style@ by
-- instantiating each component of 'GTransformerFeedForwardNetwork' via the
-- style-indexed @FFN*F@ type families below.
type family
  GTransformerFeedForwardNetworkF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout =
    GTransformerFeedForwardNetwork
      (FFNInputLayerNormF style gradient device dataType queryEmbedDim)
      (FFNInputTransformationF style gradient device dataType queryEmbedDim ffnDim)
      (FFNActivationF style)
      (FFNActivationDropoutF style hasDropout)
      (FFNOutputProjectionF style gradient device dataType queryEmbedDim ffnDim)
      (FFNOutputDropoutF style hasDropout)
      (FFNOutputLayerNormF style gradient device dataType queryEmbedDim)

-- | Specifies the layer normalization for the input.
--
-- Styles with a pre-projection layer norm get a 'NamedModel' 'LayerNorm'
-- (bias-free for T5-family, biased for Pegasus); styles without one get @()@.
type family
  FFNInputLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  -- T5 normalizes the input with a bias-free layer norm over the query embedding dimension.
  FFNInputLayerNormF 'T5 gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[queryEmbedDim]))
  -- ByT5 reuses the T5 instance.
  FFNInputLayerNormF 'ByT5 gradient device dataType queryEmbedDim =
    FFNInputLayerNormF 'T5 gradient device dataType queryEmbedDim
  -- BART (and, by delegation, MBART) has no input layer norm in this position.
  FFNInputLayerNormF 'BART _ _ _ _ =
    ()
  FFNInputLayerNormF 'MBART gradient device dataType queryEmbedDim =
    FFNInputLayerNormF 'BART gradient device dataType queryEmbedDim
  -- Pegasus uses a layer norm with bias.
  FFNInputLayerNormF 'Pegasus gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  -- BERT (and, by delegation, RoBERTa) has no input layer norm in this position.
  FFNInputLayerNormF 'BERT _ _ _ _ =
    ()
  FFNInputLayerNormF 'RoBERTa gradient device dataType queryEmbedDim =
    FFNInputLayerNormF 'BERT gradient device dataType queryEmbedDim

-- | Specifies the first input projection.
--
-- Maps from @queryEmbedDim@ up to @ffnDim@. ByT5 is special: it uses a gated
-- pair of bias-free projections ('GGate') with a 'GeluNew' in between.
--
-- NOTE: the final equation is a catch-all, so equation order is significant
-- here — do not reorder.
type family
  FFNInputTransformationF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  -- T5: a single bias-free linear projection.
  FFNInputTransformationF 'T5 gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim)
  -- ByT5: gated-GELU input transformation with two bias-free projections.
  FFNInputTransformationF 'ByT5 gradient device dataType queryEmbedDim ffnDim =
    GGate
      (NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim))
      GeluNew
      (NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim))
  -- All remaining styles: a single linear projection with bias.
  FFNInputTransformationF _ gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithBias gradient device dataType queryEmbedDim ffnDim)

-- | Specifies the activation.
--
-- NOTE(review): there is no equation for 'GPT2'; elsewhere in this module the
-- GPT2 cases are 'undefined', so it is presumably unsupported here — confirm.
type family
  FFNActivationF
    (style :: TransformerStyle) ::
    Type
  where
  FFNActivationF 'T5 = Relu
  FFNActivationF 'ByT5 = GeluNew
  FFNActivationF 'BART = Gelu
  FFNActivationF 'MBART = Gelu
  FFNActivationF 'Pegasus = Relu
  FFNActivationF 'BERT = Gelu
  FFNActivationF 'RoBERTa = Gelu

-- | Specifies the activation dropout.
--
-- BERT and RoBERTa never apply dropout in this position, regardless of
-- @hasDropout@. Closed-type-family equations are matched top to bottom, so
-- the final @_ 'WithoutDropout@ catch-all only fires for styles not already
-- resolved above — do not reorder the equations.
type family
  FFNActivationDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  FFNActivationDropoutF 'T5 'WithDropout = Dropout
  FFNActivationDropoutF 'ByT5 hasDropout = FFNActivationDropoutF 'T5 hasDropout
  FFNActivationDropoutF 'BART 'WithDropout = Dropout
  FFNActivationDropoutF 'MBART hasDropout = FFNActivationDropoutF 'BART hasDropout
  FFNActivationDropoutF 'Pegasus hasDropout = FFNActivationDropoutF 'BART hasDropout
  -- BERT-family: no activation dropout, with or without dropout elsewhere.
  FFNActivationDropoutF 'BERT _ = ()
  FFNActivationDropoutF 'RoBERTa hasDropout = FFNActivationDropoutF 'BERT hasDropout
  -- Catch-all: any remaining style without dropout reduces to ().
  FFNActivationDropoutF _ 'WithoutDropout = ()

-- | Specifies the output projection.
--
-- Maps back from @ffnDim@ down to @queryEmbedDim@ (note the argument order in
-- 'GLinearF': input dimension first). T5-family projections are bias-free;
-- all other styles carry a bias.
type family
  FFNOutputProjectionF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  FFNOutputProjectionF 'T5 gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType ffnDim queryEmbedDim)
  FFNOutputProjectionF 'ByT5 gradient device dataType queryEmbedDim ffnDim =
    FFNOutputProjectionF 'T5 gradient device dataType queryEmbedDim ffnDim
  FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithBias gradient device dataType ffnDim queryEmbedDim)
  FFNOutputProjectionF 'MBART gradient device dataType queryEmbedDim ffnDim =
    FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim
  FFNOutputProjectionF 'Pegasus gradient device dataType queryEmbedDim ffnDim =
    FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim
  FFNOutputProjectionF 'BERT gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithBias gradient device dataType ffnDim queryEmbedDim)
  FFNOutputProjectionF 'RoBERTa gradient device dataType queryEmbedDim ffnDim =
    FFNOutputProjectionF 'BERT gradient device dataType queryEmbedDim ffnDim

-- | Specifies the dropout for the output.
--
-- Unlike 'FFNActivationDropoutF', this is style-independent: 'Dropout' when
-- @hasDropout@ is 'WithDropout', @()@ otherwise.
type family
  FFNOutputDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  FFNOutputDropoutF _ 'WithDropout = Dropout
  FFNOutputDropoutF _ 'WithoutDropout = ()

-- | Specifies the layer normalization for the output.
--
-- The styles that normalize the input ('FFNInputLayerNormF') skip output
-- normalization here (T5, ByT5, Pegasus), while BART-family and BERT-family
-- apply a biased layer norm after the projection.
type family
  FFNOutputLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  -- T5 (and, by delegation, ByT5): no output layer norm.
  FFNOutputLayerNormF 'T5 _ _ _ _ =
    ()
  FFNOutputLayerNormF 'ByT5 gradient device dataType queryEmbedDim =
    FFNOutputLayerNormF 'T5 gradient device dataType queryEmbedDim
  -- BART (and, by delegation, MBART): layer norm with bias.
  FFNOutputLayerNormF 'BART gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  FFNOutputLayerNormF 'MBART gradient device dataType queryEmbedDim =
    FFNOutputLayerNormF 'BART gradient device dataType queryEmbedDim
  -- Pegasus normalizes on the input side instead.
  FFNOutputLayerNormF 'Pegasus _ _ _ _ =
    ()
  -- BERT (and, by delegation, RoBERTa): layer norm with bias.
  FFNOutputLayerNormF 'BERT gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  FFNOutputLayerNormF 'RoBERTa gradient device dataType queryEmbedDim =
    FFNOutputLayerNormF 'BERT gradient device dataType queryEmbedDim

-- | Specifies the parameters of the transformer feed forward network.
--
-- - @style@: the style of the transformer feed forward network, e.g. 'ST5', 'SByT5', etc.
-- - @gradient@: whether to compute the gradient of the network's parameters.
-- - @device@: the computational device on which the parameters are allocated.
-- - @dataType@: the data type of the parameters.
-- - @queryEmbedDim@: the dimension of the query embedding.
-- - @ffnDim@: the dimension of the feed forward network's hidden state.
-- - @hasDropout@: whether the network applies dropout.
-- - @dropoutP@: the dropout rate.
-- - @eps@: the epsilon value for numerical stability of the layer normalization.
transformerFeedForwardNetworkSpec ::
  forall style gradient device dataType queryEmbedDim ffnDim hasDropout.
  STransformerStyle style ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim queryEmbedDim ->
  SDim ffnDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout)
transformerFeedForwardNetworkSpec :: forall (style :: TransformerStyle)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
       (ffnDim :: Dim (Name Symbol) (Size Nat))
       (hasDropout :: HasDropout).
STransformerStyle style
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim queryEmbedDim
-> SDim ffnDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (GTransformerFeedForwardNetworkF
        style gradient device dataType queryEmbedDim ffnDim hasDropout)
transformerFeedForwardNetworkSpec STransformerStyle style
style SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim queryEmbedDim
queryEmbedDim SDim ffnDim
ffnDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps =
  let inputLayerNormSpec :: STransformerStyle style
-> ModelSpec
     (FFNInputLayerNormF style gradient device dataType queryEmbedDim)
inputLayerNormSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"layer_norm." LayerNormSpec
  'WithoutBias gradient device dataType ('Shape '[queryEmbedDim])
layerNormWithoutBiasSpec
      inputLayerNormSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"layer_norm." LayerNormSpec
  'WithoutBias gradient device dataType ('Shape '[queryEmbedDim])
layerNormWithoutBiasSpec
      inputLayerNormSpec STransformerStyle style
SBART = ()
      inputLayerNormSpec STransformerStyle style
SMBART = ()
      inputLayerNormSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"final_layer_norm." LayerNormSpec
  'WithBias gradient device dataType ('Shape '[queryEmbedDim])
layerNormWithBiasSpec
      inputLayerNormSpec STransformerStyle style
SBERT = ()
      inputLayerNormSpec STransformerStyle style
SRoBERTa = ()
      inputLayerNormSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      inputTransformationSpec :: STransformerStyle style
-> ModelSpec
     (FFNInputTransformationF
        style gradient device dataType queryEmbedDim ffnDim)
inputTransformationSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"DenseReluDense.wi." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
weightSpecWithoutBias SDim queryEmbedDim
queryEmbedDim SDim ffnDim
ffnDim
      inputTransformationSpec STransformerStyle style
SByT5 =
        forall layer0 activation layer1.
layer0 -> activation -> layer1 -> GGate layer0 activation layer1
GGate
          (forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"DenseReluDense.wi_0." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
weightSpecWithoutBias SDim queryEmbedDim
queryEmbedDim SDim ffnDim
ffnDim)
          GeluNew
GeluNew
          (forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"DenseReluDense.wi_1." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
weightSpecWithoutBias SDim queryEmbedDim
queryEmbedDim SDim ffnDim
ffnDim)
      inputTransformationSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"fc1." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim ffnDim
ffnDim
      inputTransformationSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"fc1." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim ffnDim
ffnDim
      inputTransformationSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"fc1." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim ffnDim
ffnDim
      inputTransformationSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"intermediate.dense." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim ffnDim
ffnDim
      inputTransformationSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"intermediate.dense." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim ffnDim
ffnDim
      inputTransformationSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      activationSpec :: STransformerStyle style -> ModelSpec (FFNActivationF style)
      activationSpec :: STransformerStyle style -> ModelSpec (FFNActivationF style)
activationSpec STransformerStyle style
ST5 = Relu
Relu
      activationSpec STransformerStyle style
SByT5 = GeluNew
GeluNew
      activationSpec STransformerStyle style
SBART = Gelu
Gelu
      activationSpec STransformerStyle style
SMBART = Gelu
Gelu
      activationSpec STransformerStyle style
SPegasus = Relu
Relu
      activationSpec STransformerStyle style
SBERT = Gelu
Gelu
      activationSpec STransformerStyle style
SRoBERTa = Gelu
Gelu
      activationSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      activationDropoutSpec :: STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (FFNActivationDropoutF style hasDropout)
activationDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      activationDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithoutDropout = ()
      activationDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      activationDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithoutDropout = ()
      activationDropoutSpec STransformerStyle style
SBART SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      activationDropoutSpec STransformerStyle style
SBART SHasDropout hasDropout
SWithoutDropout = ()
      activationDropoutSpec STransformerStyle style
SMBART SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      activationDropoutSpec STransformerStyle style
SMBART SHasDropout hasDropout
SWithoutDropout = ()
      activationDropoutSpec STransformerStyle style
SPegasus SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      activationDropoutSpec STransformerStyle style
SPegasus SHasDropout hasDropout
SWithoutDropout = ()
      activationDropoutSpec STransformerStyle style
SBERT SHasDropout hasDropout
_ = ()
      activationDropoutSpec STransformerStyle style
SRoBERTa SHasDropout hasDropout
_ = ()
      activationDropoutSpec STransformerStyle style
SGPT2 SHasDropout hasDropout
_ = forall a. HasCallStack => a
undefined
      outputProjectionSpec :: STransformerStyle style
-> ModelSpec
     (FFNOutputProjectionF
        style gradient device dataType queryEmbedDim ffnDim)
outputProjectionSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"DenseReluDense.wo." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
weightSpecWithoutBias SDim ffnDim
ffnDim SDim queryEmbedDim
queryEmbedDim
      outputProjectionSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"DenseReluDense.wo." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
weightSpecWithoutBias SDim ffnDim
ffnDim SDim queryEmbedDim
queryEmbedDim
      outputProjectionSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"fc2." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim ffnDim
ffnDim SDim queryEmbedDim
queryEmbedDim
      outputProjectionSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"fc2." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim ffnDim
ffnDim SDim queryEmbedDim
queryEmbedDim
      outputProjectionSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"fc2." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim ffnDim
ffnDim SDim queryEmbedDim
queryEmbedDim
      outputProjectionSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"output.dense." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim ffnDim
ffnDim SDim queryEmbedDim
queryEmbedDim
      outputProjectionSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"output.dense." forall a b. (a -> b) -> a -> b
$ forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias SDim ffnDim
ffnDim SDim queryEmbedDim
queryEmbedDim
      outputProjectionSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      outputDropoutSpec :: STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (FFNOutputDropoutF style hasDropout)
outputDropoutSpec STransformerStyle style
_ SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      outputDropoutSpec STransformerStyle style
_ SHasDropout hasDropout
SWithoutDropout = ()
      outputLayerNormSpec :: STransformerStyle style
-> ModelSpec
     (FFNOutputLayerNormF style gradient device dataType queryEmbedDim)
outputLayerNormSpec STransformerStyle style
ST5 = ()
      outputLayerNormSpec STransformerStyle style
SByT5 = ()
      outputLayerNormSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"final_layer_norm." LayerNormSpec
  'WithBias gradient device dataType ('Shape '[queryEmbedDim])
layerNormWithBiasSpec
      outputLayerNormSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"final_layer_norm." LayerNormSpec
  'WithBias gradient device dataType ('Shape '[queryEmbedDim])
layerNormWithBiasSpec
      outputLayerNormSpec STransformerStyle style
SPegasus = ()
      outputLayerNormSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"output.LayerNorm." LayerNormSpec
  'WithBias gradient device dataType ('Shape '[queryEmbedDim])
layerNormWithBiasSpec
      outputLayerNormSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"output.LayerNorm." LayerNormSpec
  'WithBias gradient device dataType ('Shape '[queryEmbedDim])
layerNormWithBiasSpec
      outputLayerNormSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
   in forall inputLayerNorm inputTransformation activation
       activationDropout outputProjection outputDropout outputLayerNorm.
inputLayerNorm
-> inputTransformation
-> activation
-> activationDropout
-> outputProjection
-> outputDropout
-> outputLayerNorm
-> GTransformerFeedForwardNetwork
     inputLayerNorm
     inputTransformation
     activation
     activationDropout
     outputProjection
     outputDropout
     outputLayerNorm
GTransformerFeedForwardNetwork
        (STransformerStyle style
-> ModelSpec
     (FFNInputLayerNormF style gradient device dataType queryEmbedDim)
inputLayerNormSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (FFNInputTransformationF
        style gradient device dataType queryEmbedDim ffnDim)
inputTransformationSpec STransformerStyle style
style)
        (STransformerStyle style -> ModelSpec (FFNActivationF style)
activationSpec STransformerStyle style
style)
        (STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (FFNActivationDropoutF style hasDropout)
activationDropoutSpec STransformerStyle style
style SHasDropout hasDropout
hasDropout)
        (STransformerStyle style
-> ModelSpec
     (FFNOutputProjectionF
        style gradient device dataType queryEmbedDim ffnDim)
outputProjectionSpec STransformerStyle style
style)
        (STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (FFNOutputDropoutF style hasDropout)
outputDropoutSpec STransformerStyle style
style SHasDropout hasDropout
hasDropout)
        (STransformerStyle style
-> ModelSpec
     (FFNOutputLayerNormF style gradient device dataType queryEmbedDim)
outputLayerNormSpec STransformerStyle style
style)
  where
    weightSpecWithoutBias ::
      forall inputDim outputDim.
      SDim inputDim ->
      SDim outputDim ->
      ModelSpec
        ( GLinear
            (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])))
            (NamedModel ())
        )
    weightSpecWithoutBias :: forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
weightSpecWithoutBias = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinearF hasBias gradient device dataType inputDim outputDim)
linearSpec SHasBias 'WithoutBias
SWithoutBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType
    weightSpecWithBias ::
      forall inputDim outputDim.
      SDim inputDim ->
      SDim outputDim ->
      ModelSpec
        ( GLinear
            (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])))
            (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim])))
        )
    weightSpecWithBias :: forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
weightSpecWithBias = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinearF hasBias gradient device dataType inputDim outputDim)
linearSpec SHasBias 'WithBias
SWithBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType
    layerNormWithoutBiasSpec :: LayerNormSpec
  'WithoutBias gradient device dataType ('Shape '[queryEmbedDim])
layerNormWithoutBiasSpec = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (normalizedShape :: Shape [Dim (Name Symbol) (Size Nat)]).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SShape normalizedShape
-> Double
-> LayerNormSpec hasBias gradient device dataType normalizedShape
LayerNormSpec SHasBias 'WithoutBias
SWithoutBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim queryEmbedDim
queryEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil) Double
eps
    layerNormWithBiasSpec :: LayerNormSpec
  'WithBias gradient device dataType ('Shape '[queryEmbedDim])
layerNormWithBiasSpec = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (normalizedShape :: Shape [Dim (Name Symbol) (Size Nat)]).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SShape normalizedShape
-> Double
-> LayerNormSpec hasBias gradient device dataType normalizedShape
LayerNormSpec SHasBias 'WithBias
SWithBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim queryEmbedDim
queryEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil) Double
eps

-- | 'HasInitialize' instance for 'GTransformerFeedForwardNetwork'.
--
-- The instance body is empty: initialization is derived generically from the
-- 'HasInitialize' instances of the seven sub-models. The generator device is
-- threaded through the sub-models in order (input layer norm, input
-- transformation, activation, activation dropout, output projection, output
-- dropout, output layer norm), as reflected by the chained
-- @generatorDevice0@ .. @generatorDevice5@ constraints.
instance
  ( HasInitialize inputLayerNorm generatorDevice inputLayerNorm' generatorDevice0,
    HasInitialize inputTransformation generatorDevice0 inputTransformation' generatorDevice1,
    HasInitialize activation generatorDevice1 activation' generatorDevice2,
    HasInitialize activationDropout generatorDevice2 activationDropout' generatorDevice3,
    HasInitialize outputProjection generatorDevice3 outputProjection' generatorDevice4,
    HasInitialize outputDropout generatorDevice4 outputDropout' generatorDevice5,
    HasInitialize outputLayerNorm generatorDevice5 outputLayerNorm' generatorOutputDevice
  ) =>
  HasInitialize
    (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm)
    generatorDevice
    (GTransformerFeedForwardNetwork inputLayerNorm' inputTransformation' activation' activationDropout' outputProjection' outputDropout' outputLayerNorm')
    generatorOutputDevice

-- | 'HasStateDict' instance for 'GTransformerFeedForwardNetwork'.
--
-- The instance body is empty: state-dict loading and saving are derived
-- generically, delegating to the 'HasStateDict' instances of the seven
-- sub-models in field order.
instance
  ( HasStateDict inputLayerNorm,
    HasStateDict inputTransformation,
    HasStateDict activation,
    HasStateDict activationDropout,
    HasStateDict outputProjection,
    HasStateDict outputDropout,
    HasStateDict outputLayerNorm
  ) =>
  HasStateDict (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm)

-- | 'HasForward' instance for 'GTransformerFeedForwardNetwork'.
--
-- @
--       ┌───────┐
--       │ query ├────────┐
--       └───┬───┘        │
--           │            │
--           ▼            │
--  (ffnInputLayerNorm)   │
--           ▼            │
-- ffnInputTransformation │
--           ▼            │
--     ffnActivation      │
--           ▼            │
-- (ffnActivationDropout) │
--           ▼            │
--   ffnOutputProjection  │
--           ▼            │
--    ffnOutputDropout    │
--           │            │
--           ▼            │
--          add◄──────────┘
--           │
--           ▼
--  (ffnOutputLayerNorm)
--           │
--           ▼
--       ┌───────┐
--       │ query │
--       └───────┘
-- @
instance
  ( HasForward
      inputLayerNorm
      (Tensor queryGradient queryLayout queryDevice queryDataType queryShape)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      inputTransformation
      tensor0
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      activation
      tensor1
      generatorDevice1
      tensor2
      generatorDevice2,
    HasForward
      activationDropout
      tensor2
      generatorDevice2
      tensor3
      generatorDevice3,
    HasForward
      outputProjection
      tensor3
      generatorDevice3
      tensor4
      generatorDevice4,
    HasForward
      outputDropout
      tensor4
      generatorDevice4
      (Tensor queryGradient5 queryLayout5 queryDevice5 queryDataType5 queryShape5)
      generatorDevice5,
    HasForward
      outputLayerNorm
      (Tensor (queryGradient <|> queryGradient5) (queryLayout <+> queryLayout5) (queryDevice <+> queryDevice5) (queryDataType <+> queryDataType5) (BroadcastShapesF queryShape queryShape5))
      generatorDevice5
      output
      generatorOutputDevice,
    Catch (BroadcastShapesF queryShape queryShape5)
  ) =>
  HasForward
    (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm)
    (Tensor queryGradient queryLayout queryDevice queryDataType queryShape)
    generatorDevice
    output
    generatorOutputDevice
  where
  -- Threads the query through the feed-forward pipeline using the indexed
  -- state monad ('IxStateT' over the random generator): (optional) input
  -- layer norm, input transformation, activation, (optional) activation
  -- dropout, output projection, and output dropout. The result is then
  -- added to the original query (residual connection; 'add' broadcasts the
  -- shapes, hence the 'BroadcastShapesF'/'Catch' constraints) before the
  -- (optional) output layer norm is applied.
  forward GTransformerFeedForwardNetwork {..} query =
    runIxStateT $
      ireturn query
        >>>= IxStateT . forward ffnInputLayerNorm
        >>>= IxStateT . forward ffnInputTransformation
        >>>= IxStateT . forward ffnActivation
        >>>= IxStateT . forward ffnActivationDropout
        >>>= IxStateT . forward ffnOutputProjection
        >>>= IxStateT . forward ffnOutputDropout
        >>>= ilift . (query `add`)
        >>>= IxStateT . forward ffnOutputLayerNorm