{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}

module Torch.GraduallyTyped.NN.Transformer.GTransformer where

import Control.Monad.Indexed ((>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed (IxPointed (ireturn))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Dropout (Dropout (..))
import Torch.GraduallyTyped.NN.Normalization (LayerNorm (..), LayerNormSpec (..))
import Torch.GraduallyTyped.NN.Sparse (Embedding (..), EmbeddingSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GStack (DecoderStackF, EncoderStackF, decoderStackSpec, encoderStackSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), HasDropout (..), SHasBias (..), SHasDropout (..))
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.Prelude.Maybe (SMaybe (SNothing))
import Torch.GraduallyTyped.Prelude.TypeLits (SNat (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim, SSelectDim (..), SShape (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (TransposeF, UnsqueezeF, sTranspose, sUnsqueeze, transpose, unsqueeze)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

-- | Generic transformer.
-- Can specialize to either encoder or decoder.
--
-- Every component is an ordinary type parameter, so a component can be
-- switched off by instantiating it with @()@.
--
-- - @posEnc@: an absolute positional encoding layer as used by, e.g., BERT.
-- - @relPosEnc@: a relative positional encoding layer as used by, e.g., T5.
-- - @initialLayerNorm@: a layer normalization layer for the embeddings.
-- - @initialDropout@: a dropout layer for the embeddings.
-- - @stack@: a stack of transformer blocks.
-- - @finalLayerNorm@: the final layer normalization layer.
-- - @finalDropout@: the final dropout layer.
data
  GTransformer
    (posEnc :: Type)
    (relPosEnc :: Type)
    (initialLayerNorm :: Type)
    (initialDropout :: Type)
    (stack :: Type)
    (finalLayerNorm :: Type)
    (finalDropout :: Type)
  where
  GTransformer ::
    forall posEnc relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout.
    { -- | absolute positional encoding
      tPosEnc :: posEnc,
      -- | relative positional encoding
      tRelPosEnc :: relPosEnc,
      -- | initial layer norm
      tInitialLayerNorm :: initialLayerNorm,
      -- | initial dropout
      tInitialDropout :: initialDropout,
      -- | transformer block stack
      tStack :: stack,
      -- | final layer norm
      tFinalLayerNorm :: finalLayerNorm,
      -- | final dropout
      tFinalDropout :: finalDropout
    } ->
    GTransformer posEnc relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout
  deriving stock (Eq, Ord, Show, Generic)

-- The specification of a 'GTransformer' is again a 'GTransformer', with each
-- of the seven component types replaced by that component's own 'ModelSpec'.
type instance
  ModelSpec (GTransformer pos relPos initNorm initDrop blocks finNorm finDrop) =
    GTransformer
      (ModelSpec pos)
      (ModelSpec relPos)
      (ModelSpec initNorm)
      (ModelSpec initDrop)
      (ModelSpec blocks)
      (ModelSpec finNorm)
      (ModelSpec finDrop)

-- | Type of a transformer in an encoder configuration.
--
-- Resolves to a 'GTransformer' whose seven components are each chosen by a
-- dedicated type family: 'TEPosEncF', 'TERelPosEncF', 'TEInitialLayerNormF',
-- 'TEInitialDropoutF', 'TEStackF', 'TEFinalLayerNormF', and 'TEFinalDropoutF'.
type family
  TransformerEncoderF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  TransformerEncoderF sty layers grad dev dtype headD headEmbedD embedD inputEmbedD ffnD posEncD dropout =
    GTransformer
      (TEPosEncF sty grad dev dtype inputEmbedD posEncD)
      (TERelPosEncF sty grad dev dtype headD posEncD)
      (TEInitialLayerNormF sty grad dev dtype inputEmbedD)
      (TEInitialDropoutF sty dropout)
      (TEStackF sty layers grad dev dtype headD headEmbedD embedD inputEmbedD ffnD dropout)
      (TEFinalLayerNormF sty grad dev dtype inputEmbedD)
      (TEFinalDropoutF sty dropout)

-- | Selects the absolute positional encoding layer of a transformer encoder.
--
-- For @'T5@ and @'ByT5@ the component is @()@. For @'BART@, @'MBART@,
-- @'Pegasus@, @'BERT@, and @'RoBERTa@ it is a 'NamedModel' wrapping a dense
-- 'Embedding' parameterised by @posEncDim@ and @inputEmbedDim@.
type family
  TEPosEncF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TEPosEncF 'T5 _ _ _ _ _ = ()
  TEPosEncF 'ByT5 grad dev dtype inputEmbedD posEncD =
    TEPosEncF 'T5 grad dev dtype inputEmbedD posEncD
  TEPosEncF 'BART grad dev dtype inputEmbedD posEncD =
    NamedModel (Embedding grad ('Layout 'Dense) dev dtype posEncD inputEmbedD 'Nothing)
  TEPosEncF 'MBART grad dev dtype inputEmbedD posEncD =
    TEPosEncF 'BART grad dev dtype inputEmbedD posEncD
  TEPosEncF 'Pegasus grad dev dtype inputEmbedD posEncD =
    TEPosEncF 'BART grad dev dtype inputEmbedD posEncD
  TEPosEncF 'BERT grad dev dtype inputEmbedD posEncD =
    NamedModel (Embedding grad ('Layout 'Dense) dev dtype posEncD inputEmbedD 'Nothing)
  TEPosEncF 'RoBERTa grad dev dtype inputEmbedD posEncD =
    TEPosEncF 'BERT grad dev dtype inputEmbedD posEncD

-- | Selects the relative positional encoding layer of a transformer encoder.
--
-- For @'T5@ and @'ByT5@ the component is a 'NamedModel' wrapping a dense
-- 'Embedding' parameterised by @posEncDim@ and @headDim@. For @'BART@,
-- @'MBART@, @'Pegasus@, @'BERT@, and @'RoBERTa@ it is @()@.
type family
  TERelPosEncF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TERelPosEncF 'T5 grad dev dtype headD posEncD =
    NamedModel (Embedding grad ('Layout 'Dense) dev dtype posEncD headD 'Nothing)
  TERelPosEncF 'ByT5 grad dev dtype headD posEncD =
    TERelPosEncF 'T5 grad dev dtype headD posEncD
  TERelPosEncF 'BART _ _ _ _ _ = ()
  TERelPosEncF 'MBART grad dev dtype headD posEncD =
    TERelPosEncF 'BART grad dev dtype headD posEncD
  TERelPosEncF 'Pegasus grad dev dtype headD posEncD =
    TERelPosEncF 'BART grad dev dtype headD posEncD
  TERelPosEncF 'BERT _ _ _ _ _ = ()
  TERelPosEncF 'RoBERTa grad dev dtype headD posEncD =
    TERelPosEncF 'BERT grad dev dtype headD posEncD

-- | Selects the initial layer normalization layer of a transformer encoder.
--
-- For @'T5@, @'ByT5@, and @'Pegasus@ the component is @()@. For @'BART@,
-- @'MBART@, @'BERT@, and @'RoBERTa@ it is a 'NamedModel' wrapping a
-- 'LayerNorm' with bias over the shape @[inputEmbedDim]@.
type family
  TEInitialLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TEInitialLayerNormF 'T5 _ _ _ _ = ()
  TEInitialLayerNormF 'ByT5 grad dev dtype inputEmbedD =
    TEInitialLayerNormF 'T5 grad dev dtype inputEmbedD
  TEInitialLayerNormF 'BART grad dev dtype inputEmbedD =
    NamedModel (LayerNorm 'WithBias grad dev dtype ('Shape '[inputEmbedD]))
  TEInitialLayerNormF 'MBART grad dev dtype inputEmbedD =
    TEInitialLayerNormF 'BART grad dev dtype inputEmbedD
  TEInitialLayerNormF 'Pegasus _ _ _ _ = ()
  TEInitialLayerNormF 'BERT grad dev dtype inputEmbedD =
    NamedModel (LayerNorm 'WithBias grad dev dtype ('Shape '[inputEmbedD]))
  TEInitialLayerNormF 'RoBERTa grad dev dtype inputEmbedD =
    TEInitialLayerNormF 'BERT grad dev dtype inputEmbedD

-- | Selects the initial dropout layer of a transformer encoder.
--
-- With @'WithoutDropout@ the component is @()@ for any style. With
-- @'WithDropout@ it is 'Dropout' for each of the styles listed below.
type family
  TEInitialDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  -- Dropout disabled: no layer, independent of the style. This equation is
  -- apart from all equations below, so its position does not affect reduction.
  TEInitialDropoutF _ 'WithoutDropout = ()
  -- Dropout enabled: one equation per style.
  TEInitialDropoutF 'T5 'WithDropout = Dropout
  TEInitialDropoutF 'ByT5 'WithDropout = Dropout
  TEInitialDropoutF 'BART 'WithDropout = Dropout
  TEInitialDropoutF 'MBART 'WithDropout = Dropout
  TEInitialDropoutF 'Pegasus 'WithDropout = Dropout
  TEInitialDropoutF 'BERT 'WithDropout = Dropout
  TEInitialDropoutF 'RoBERTa 'WithDropout = Dropout

-- | Selects the transformer block stack of a transformer encoder.
--
-- For every style this is a 'NamedModel' wrapping the 'EncoderStackF' built
-- from the same parameters.
type family
  TEStackF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  TEStackF sty layers grad dev dtype headD headEmbedD embedD inputEmbedD ffnD dropout =
    NamedModel
      ( EncoderStackF
          sty
          layers
          grad
          dev
          dtype
          headD
          headEmbedD
          embedD
          inputEmbedD
          ffnD
          dropout
      )

-- | Selects the final layer normalization layer of a transformer encoder.
--
-- For @'T5@ and @'ByT5@ the component is a 'NamedModel' wrapping a
-- 'LayerNorm' without bias; for @'Pegasus@ it is a 'NamedModel' wrapping a
-- 'LayerNorm' with bias. Both are over the shape @[inputEmbedDim]@. For
-- @'BART@, @'MBART@, @'BERT@, and @'RoBERTa@ the component is @()@.
type family
  TEFinalLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TEFinalLayerNormF 'T5 grad dev dtype inputEmbedD =
    NamedModel (LayerNorm 'WithoutBias grad dev dtype ('Shape '[inputEmbedD]))
  TEFinalLayerNormF 'ByT5 grad dev dtype inputEmbedD =
    TEFinalLayerNormF 'T5 grad dev dtype inputEmbedD
  TEFinalLayerNormF 'BART _ _ _ _ = ()
  TEFinalLayerNormF 'MBART grad dev dtype inputEmbedD =
    TEFinalLayerNormF 'BART grad dev dtype inputEmbedD
  TEFinalLayerNormF 'Pegasus grad dev dtype inputEmbedD =
    NamedModel (LayerNorm 'WithBias grad dev dtype ('Shape '[inputEmbedD]))
  TEFinalLayerNormF 'BERT _ _ _ _ = ()
  TEFinalLayerNormF 'RoBERTa grad dev dtype inputEmbedD =
    TEFinalLayerNormF 'BERT grad dev dtype inputEmbedD

-- | Selects the final dropout layer of a transformer encoder.
--
-- Only @'T5@ and @'ByT5@ with @'WithDropout@ resolve to 'Dropout'. The other
-- listed styles resolve to @()@ whatever @hasDropout@ is, and
-- @'WithoutDropout@ resolves to @()@ for any style.
type family
  TEFinalDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  -- Styles that never have a final dropout layer.
  TEFinalDropoutF 'BART _ = ()
  TEFinalDropoutF 'MBART _ = ()
  TEFinalDropoutF 'Pegasus _ = ()
  TEFinalDropoutF 'BERT _ = ()
  TEFinalDropoutF 'RoBERTa _ = ()
  -- Styles that have a final dropout layer when dropout is enabled.
  TEFinalDropoutF 'T5 'WithDropout = Dropout
  TEFinalDropoutF 'ByT5 'WithDropout = Dropout
  -- Dropout disabled: no layer, independent of the style.
  TEFinalDropoutF _ 'WithoutDropout = ()

-- | Specifies the parameters of a transformer in an encoder configuration.
--
-- - @style@: the style of the transformer encoder, e.g. 'ST5', 'SByT5', etc.
-- - @numLayers@: the number of transformer blocks in the encoder's stack.
-- - @gradient@: whether to compute the gradient of the encoder's parameters.
-- - @device@: the computational device on which the encoder is allocated.
-- - @dataType@: the data type of the encoder's parameters.
-- - @headDim@: the dimension of all transformer heads in the encoder.
-- - @headEmbedDim@: the dimension of the transformer head embeddings.
-- - @embedDim@: the dimension of the transformer embeddings.
-- - @inputEmbedDim@: the dimension of the transformer query embeddings.
-- - @ffnDim@: the dimension of the feed-forward network.
-- - @posEncDim@: the dimension of the positional encoding.
-- - @hasDropout@: whether or not the encoder uses dropout layers.
-- - @dropoutP@: the dropout rate.
-- - @eps@: the epsilon value for numerical stability of the layer normalization.
transformerEncoderSpec ::
  forall style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout.
  STransformerStyle style ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim inputEmbedDim ->
  SDim ffnDim ->
  SDim posEncDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (TransformerEncoderF style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout)
transformerEncoderSpec :: forall (style :: TransformerStyle) (numLayers :: Natural)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Natural))
       (headEmbedDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (inputEmbedDim :: Dim (Name Symbol) (Size Natural))
       (ffnDim :: Dim (Name Symbol) (Size Natural))
       (posEncDim :: Dim (Name Symbol) (Size Natural))
       (hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (TransformerEncoderF
        style
        numLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        inputEmbedDim
        ffnDim
        posEncDim
        hasDropout)
transformerEncoderSpec STransformerStyle style
style SNat numLayers
numLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps =
  let posEncSpec :: STransformerStyle style
-> ModelSpec
     (TEPosEncF style gradient device dataType inputEmbedDim posEncDim)
posEncSpec STransformerStyle style
ST5 = ()
      posEncSpec STransformerStyle style
SByT5 = ()
      posEncSpec STransformerStyle style
SBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"embed_positions." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  inputEmbedDim
  'Nothing
posEncSpec'
      posEncSpec STransformerStyle style
SMBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"embed_positions." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  inputEmbedDim
  'Nothing
posEncSpec'
      posEncSpec STransformerStyle style
SPegasus = forall model. Text -> model -> NamedModel model
NamedModel Text
"embed_positions." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  inputEmbedDim
  'Nothing
posEncSpec'
      posEncSpec STransformerStyle style
SBERT = forall model. Text -> model -> NamedModel model
NamedModel Text
"embeddings.position_embeddings." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  inputEmbedDim
  'Nothing
posEncSpec'
      posEncSpec STransformerStyle style
SRoBERTa = forall model. Text -> model -> NamedModel model
NamedModel Text
"embeddings.position_embeddings." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  inputEmbedDim
  'Nothing
posEncSpec'
      posEncSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      relPosEncSpec :: STransformerStyle style
-> ModelSpec
     (TERelPosEncF style gradient device dataType headDim posEncDim)
relPosEncSpec STransformerStyle style
ST5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"block.0.layer.0.SelfAttention.relative_attention_bias." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  headDim
  'Nothing
relPosEncSpec'
      relPosEncSpec STransformerStyle style
SByT5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"block.0.layer.0.SelfAttention.relative_attention_bias." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  headDim
  'Nothing
relPosEncSpec'
      relPosEncSpec STransformerStyle style
SBART = ()
      relPosEncSpec STransformerStyle style
SMBART = ()
      relPosEncSpec STransformerStyle style
SPegasus = ()
      relPosEncSpec STransformerStyle style
SBERT = ()
      relPosEncSpec STransformerStyle style
SRoBERTa = ()
      relPosEncSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      initialLayerNormSpec :: STransformerStyle style
-> ModelSpec
     (TEInitialLayerNormF style gradient device dataType inputEmbedDim)
initialLayerNormSpec STransformerStyle style
ST5 = ()
      initialLayerNormSpec STransformerStyle style
SByT5 = ()
      initialLayerNormSpec STransformerStyle style
SBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"layernorm_embedding." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[inputEmbedDim])
layerNormSpec' SHasBias 'WithBias
SWithBias
      initialLayerNormSpec STransformerStyle style
SMBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"layernorm_embedding." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[inputEmbedDim])
layerNormSpec' SHasBias 'WithBias
SWithBias
      initialLayerNormSpec STransformerStyle style
SPegasus = ()
      initialLayerNormSpec STransformerStyle style
SBERT = forall model. Text -> model -> NamedModel model
NamedModel Text
"embeddings.LayerNorm." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[inputEmbedDim])
layerNormSpec' SHasBias 'WithBias
SWithBias
      initialLayerNormSpec STransformerStyle style
SRoBERTa = forall model. Text -> model -> NamedModel model
NamedModel Text
"embeddings.LayerNorm." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[inputEmbedDim])
layerNormSpec' SHasBias 'WithBias
SWithBias
      initialLayerNormSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      initialDropoutSpec :: STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (TEInitialDropoutF style hasDropout)
initialDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SBART SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SBART SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SMBART SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SMBART SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SPegasus SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SPegasus SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SBERT SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SBERT SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SRoBERTa SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SRoBERTa SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SGPT2 SHasDropout hasDropout
_ = forall a. HasCallStack => a
undefined
      stackSpec :: STransformerStyle style
-> NamedModel
     (GTransformerStack
        (VectorSpec
           numLayers
           (GTransformerBlock
              (NamedModel
                 (GSelfAttention
                    (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                    (NamedModel
                       (GMultiHeadAttention
                          headDim
                          headEmbedDim
                          embedDim
                          (QInProjF style gradient device dataType inputEmbedDim embedDim)
                          (KInProjF style gradient device dataType inputEmbedDim embedDim)
                          (VInProjF style gradient device dataType inputEmbedDim embedDim)
                          (OutProjF style gradient device dataType embedDim inputEmbedDim)
                          (DropoutF style hasDropout)))
                    (SADropoutF style hasDropout)
                    (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
              ()
              (NamedModel
                 (GTransformerFeedForwardNetwork
                    (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                    (FFNInputTransformationF
                       style gradient device dataType inputEmbedDim ffnDim)
                    (FFNActivationF style)
                    (FFNActivationDropoutF style hasDropout)
                    (FFNOutputProjectionF
                       style gradient device dataType inputEmbedDim ffnDim)
                    (FFNOutputDropoutF style hasDropout)
                    (FFNOutputLayerNormF
                       style gradient device dataType inputEmbedDim))))))
stackSpec STransformerStyle style
ST5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"block." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF style gradient device dataType inputEmbedDim embedDim)
                       (KInProjF style gradient device dataType inputEmbedDim embedDim)
                       (VInProjF style gradient device dataType inputEmbedDim embedDim)
                       (OutProjF style gradient device dataType embedDim inputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
           ()
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType inputEmbedDim)))))
stackSpec' STransformerStyle 'T5
ST5
      stackSpec STransformerStyle style
SByT5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"block." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF style gradient device dataType inputEmbedDim embedDim)
                       (KInProjF style gradient device dataType inputEmbedDim embedDim)
                       (VInProjF style gradient device dataType inputEmbedDim embedDim)
                       (OutProjF style gradient device dataType embedDim inputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
           ()
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType inputEmbedDim)))))
stackSpec' STransformerStyle 'ByT5
SByT5
      stackSpec STransformerStyle style
SBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"layers." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF style gradient device dataType inputEmbedDim embedDim)
                       (KInProjF style gradient device dataType inputEmbedDim embedDim)
                       (VInProjF style gradient device dataType inputEmbedDim embedDim)
                       (OutProjF style gradient device dataType embedDim inputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
           ()
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType inputEmbedDim)))))
stackSpec' STransformerStyle 'BART
SBART
      stackSpec STransformerStyle style
SMBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"layers." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF style gradient device dataType inputEmbedDim embedDim)
                       (KInProjF style gradient device dataType inputEmbedDim embedDim)
                       (VInProjF style gradient device dataType inputEmbedDim embedDim)
                       (OutProjF style gradient device dataType embedDim inputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
           ()
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType inputEmbedDim)))))
stackSpec' STransformerStyle 'MBART
SMBART
      stackSpec STransformerStyle style
SPegasus = forall model. Text -> model -> NamedModel model
NamedModel Text
"layers." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF style gradient device dataType inputEmbedDim embedDim)
                       (KInProjF style gradient device dataType inputEmbedDim embedDim)
                       (VInProjF style gradient device dataType inputEmbedDim embedDim)
                       (OutProjF style gradient device dataType embedDim inputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
           ()
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType inputEmbedDim)))))
stackSpec' STransformerStyle 'Pegasus
SPegasus
      stackSpec STransformerStyle style
SBERT = forall model. Text -> model -> NamedModel model
NamedModel Text
"encoder.layer." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF style gradient device dataType inputEmbedDim embedDim)
                       (KInProjF style gradient device dataType inputEmbedDim embedDim)
                       (VInProjF style gradient device dataType inputEmbedDim embedDim)
                       (OutProjF style gradient device dataType embedDim inputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
           ()
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType inputEmbedDim)))))
stackSpec' STransformerStyle 'BERT
SBERT
      stackSpec STransformerStyle style
SRoBERTa = forall model. Text -> model -> NamedModel model
NamedModel Text
"encoder.layer." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF style gradient device dataType inputEmbedDim embedDim)
                       (KInProjF style gradient device dataType inputEmbedDim embedDim)
                       (VInProjF style gradient device dataType inputEmbedDim embedDim)
                       (OutProjF style gradient device dataType embedDim inputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
           ()
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType inputEmbedDim)))))
stackSpec' STransformerStyle 'RoBERTa
SRoBERTa
      stackSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      finalLayerNormSpec :: STransformerStyle style
-> ModelSpec
     (TEFinalLayerNormF style gradient device dataType inputEmbedDim)
finalLayerNormSpec STransformerStyle style
ST5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"final_layer_norm." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[inputEmbedDim])
layerNormSpec' SHasBias 'WithoutBias
SWithoutBias
      finalLayerNormSpec STransformerStyle style
SByT5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"final_layer_norm." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[inputEmbedDim])
layerNormSpec' SHasBias 'WithoutBias
SWithoutBias
      finalLayerNormSpec STransformerStyle style
SBART = ()
      finalLayerNormSpec STransformerStyle style
SMBART = ()
      finalLayerNormSpec STransformerStyle style
SPegasus = forall model. Text -> model -> NamedModel model
NamedModel Text
"layer_norm." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[inputEmbedDim])
layerNormSpec' SHasBias 'WithBias
SWithBias
      finalLayerNormSpec STransformerStyle style
SBERT = ()
      finalLayerNormSpec STransformerStyle style
SRoBERTa = ()
      finalLayerNormSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      finalDropoutSpec :: STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (TEFinalDropoutF style hasDropout)
finalDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      finalDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithoutDropout = ()
      finalDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      finalDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithoutDropout = ()
      finalDropoutSpec STransformerStyle style
SBART SHasDropout hasDropout
_ = ()
      finalDropoutSpec STransformerStyle style
SMBART SHasDropout hasDropout
_ = ()
      finalDropoutSpec STransformerStyle style
SPegasus SHasDropout hasDropout
_ = ()
      finalDropoutSpec STransformerStyle style
SBERT SHasDropout hasDropout
_ = ()
      finalDropoutSpec STransformerStyle style
SRoBERTa SHasDropout hasDropout
_ = ()
      finalDropoutSpec STransformerStyle style
SGPT2 SHasDropout hasDropout
_ = forall a. HasCallStack => a
undefined
   in forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
posEnc
-> relPosEnc
-> initialLayerNorm
-> initialDropout
-> stack
-> finalLayerNorm
-> finalDropout
-> GTransformer
     posEnc
     relPosEnc
     initialLayerNorm
     initialDropout
     stack
     finalLayerNorm
     finalDropout
GTransformer
        (STransformerStyle style
-> ModelSpec
     (TEPosEncF style gradient device dataType inputEmbedDim posEncDim)
posEncSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (TERelPosEncF style gradient device dataType headDim posEncDim)
relPosEncSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (TEInitialLayerNormF style gradient device dataType inputEmbedDim)
initialLayerNormSpec STransformerStyle style
style)
        (STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (TEInitialDropoutF style hasDropout)
initialDropoutSpec STransformerStyle style
style SHasDropout hasDropout
hasDropout)
        (STransformerStyle style
-> NamedModel
     (GTransformerStack
        (VectorSpec
           numLayers
           (GTransformerBlock
              (NamedModel
                 (GSelfAttention
                    (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                    (NamedModel
                       (GMultiHeadAttention
                          headDim
                          headEmbedDim
                          embedDim
                          (QInProjF style gradient device dataType inputEmbedDim embedDim)
                          (KInProjF style gradient device dataType inputEmbedDim embedDim)
                          (VInProjF style gradient device dataType inputEmbedDim embedDim)
                          (OutProjF style gradient device dataType embedDim inputEmbedDim)
                          (DropoutF style hasDropout)))
                    (SADropoutF style hasDropout)
                    (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
              ()
              (NamedModel
                 (GTransformerFeedForwardNetwork
                    (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                    (FFNInputTransformationF
                       style gradient device dataType inputEmbedDim ffnDim)
                    (FFNActivationF style)
                    (FFNActivationDropoutF style hasDropout)
                    (FFNOutputProjectionF
                       style gradient device dataType inputEmbedDim ffnDim)
                    (FFNOutputDropoutF style hasDropout)
                    (FFNOutputLayerNormF
                       style gradient device dataType inputEmbedDim))))))
stackSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (TEFinalLayerNormF style gradient device dataType inputEmbedDim)
finalLayerNormSpec STransformerStyle style
style)
        (STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (TEFinalDropoutF style hasDropout)
finalDropoutSpec STransformerStyle style
style SHasDropout hasDropout
hasDropout)
  where
    stackSpec' :: _
    stackSpec' :: STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF style gradient device dataType inputEmbedDim embedDim)
                       (KInProjF style gradient device dataType inputEmbedDim embedDim)
                       (VInProjF style gradient device dataType inputEmbedDim embedDim)
                       (OutProjF style gradient device dataType embedDim inputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
           ()
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType inputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType inputEmbedDim)))))
stackSpec' STransformerStyle style
style' = forall (style :: TransformerStyle) (numLayers :: Natural)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Natural))
       (headEmbedDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (queryEmbedDim :: Dim (Name Symbol) (Size Natural))
       (ffnDim :: Dim (Name Symbol) (Size Natural))
       (hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim queryEmbedDim
-> SDim ffnDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (EncoderStackF
        style
        numLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        queryEmbedDim
        ffnDim
        hasDropout)
encoderStackSpec STransformerStyle style
style' SNat numLayers
numLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps
    layerNormSpec' :: _
    layerNormSpec' :: SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[inputEmbedDim])
layerNormSpec' SHasBias hasBias
hasBias = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (normalizedShape :: Shape [Dim (Name Symbol) (Size Natural)]).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SShape normalizedShape
-> Double
-> LayerNormSpec hasBias gradient device dataType normalizedShape
LayerNormSpec SHasBias hasBias
hasBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType (forall (dims :: [Dim (Name Symbol) (Size Natural)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim inputEmbedDim
inputEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil) Double
eps
    relPosEncSpec' :: EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  headDim
  'Nothing
relPosEncSpec' = forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (paddingIdx :: Maybe Natural).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim posEncDim
posEncDim SDim headDim
headDim forall a. SMaybe 'Nothing
SNothing
    posEncSpec' :: EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  inputEmbedDim
  'Nothing
posEncSpec' = forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (paddingIdx :: Maybe Natural).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim posEncDim
posEncDim SDim inputEmbedDim
inputEmbedDim forall a. SMaybe 'Nothing
SNothing

-- | Specifies the type of a transformer in a decoder configuration.
--
-- The result is a 'GTransformer' whose seven components are, in order:
-- the absolute positional encoding ('TDPosEncF'),
-- the relative positional encoding ('TDRelPosEncF'),
-- the initial layer normalization ('TDInitialLayerNormF'),
-- the initial dropout ('TDInitialDropoutF'),
-- the transformer block stack ('TDStackF'),
-- the final layer normalization ('TDFinalLayerNormF'), and
-- the final dropout ('TDFinalDropoutF').
type family
  TransformerDecoderF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (decoderInputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (encoderOutputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  -- Only the stack receives @encoderOutputEmbedDim@; the positional encoding
  -- and both layer normalizations are parameterized by @decoderInputEmbedDim@.
  TransformerDecoderF style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim posEncDim hasDropout =
    GTransformer
      (TDPosEncF style gradient device dataType decoderInputEmbedDim posEncDim)
      (TDRelPosEncF style gradient device dataType headDim posEncDim)
      (TDInitialLayerNormF style gradient device dataType decoderInputEmbedDim)
      (TDInitialDropoutF style hasDropout)
      (TDStackF style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim hasDropout)
      (TDFinalLayerNormF style gradient device dataType decoderInputEmbedDim)
      (TDFinalDropoutF style hasDropout)

-- | Selects the absolute positional encoding layer of a transformer decoder.
--
-- The @T5@ and @ByT5@ styles have no absolute positional encoding and
-- resolve to @()@. The @BART@, @MBART@, and @Pegasus@ styles all resolve to
-- the same named, dense 'Embedding' parameterized by @posEncDim@ and
-- @inputEmbedDim@ with no padding index.
type family
  TDPosEncF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  -- T5 family: no absolute positional encoding.
  TDPosEncF 'T5 _ _ _ _ _ = ()
  TDPosEncF 'ByT5 _ _ _ _ _ = ()
  -- BART family: every member shares one embedding layout.
  TDPosEncF 'BART gradient device dataType inputEmbedDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim inputEmbedDim 'Nothing)
  TDPosEncF 'MBART gradient device dataType inputEmbedDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim inputEmbedDim 'Nothing)
  TDPosEncF 'Pegasus gradient device dataType inputEmbedDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim inputEmbedDim 'Nothing)

-- | Selects the relative positional encoding layer of a transformer decoder.
--
-- The @T5@ and @ByT5@ styles resolve to the same named, dense 'Embedding'
-- parameterized by @posEncDim@ and @headDim@ with no padding index.
-- The @BART@, @MBART@, and @Pegasus@ styles have no relative positional
-- encoding and resolve to @()@.
type family
  TDRelPosEncF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  -- T5 family: both members share one embedding layout.
  TDRelPosEncF 'T5 gradient device dataType headDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim headDim 'Nothing)
  TDRelPosEncF 'ByT5 gradient device dataType headDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim headDim 'Nothing)
  -- BART family: no relative positional encoding.
  TDRelPosEncF 'BART _ _ _ _ _ = ()
  TDRelPosEncF 'MBART _ _ _ _ _ = ()
  TDRelPosEncF 'Pegasus _ _ _ _ _ = ()

-- | Selects the initial layer normalization layer of a transformer decoder.
--
-- Only the @BART@ and @MBART@ styles have one: a named 'LayerNorm' with bias
-- over the shape @'[inputEmbedDim]@. The @T5@, @ByT5@, and @Pegasus@ styles
-- resolve to @()@.
type family
  TDInitialLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TDInitialLayerNormF 'T5 _ _ _ _ = ()
  TDInitialLayerNormF 'ByT5 _ _ _ _ = ()
  TDInitialLayerNormF 'BART gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))
  TDInitialLayerNormF 'MBART gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))
  TDInitialLayerNormF 'Pegasus _ _ _ _ = ()

-- | Specifies the initial dropout layer of a transformer decoder.
--
-- With @'WithDropout@, the @T5@, @ByT5@, @BART@, @MBART@, and @Pegasus@
-- styles all resolve to 'Dropout'. With @'WithoutDropout@, every style
-- resolves to @()@. No equation covers any other style combined with
-- @'WithDropout@, so the family does not reduce in that case.
type family
  TDInitialDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  TDInitialDropoutF 'T5 'WithDropout = Dropout
  TDInitialDropoutF 'ByT5 'WithDropout = Dropout
  TDInitialDropoutF 'BART 'WithDropout = Dropout
  TDInitialDropoutF 'MBART 'WithDropout = Dropout
  TDInitialDropoutF 'Pegasus 'WithDropout = Dropout
  -- Catch-all for the dropout-free configuration, independent of style.
  TDInitialDropoutF _ 'WithoutDropout = ()

-- | Specifies the transformer block stack of a transformer decoder.
--
-- For every style this is a 'NamedModel' wrapping the 'DecoderStackF'
-- instantiated with the same parameters, in the same order.
type family
  TDStackF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (decoderInputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (encoderOutputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  -- Single style-independent equation; any per-style variation is delegated
  -- to 'DecoderStackF' itself.
  TDStackF style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim hasDropout =
    NamedModel (DecoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim hasDropout)

-- | Selects the final layer normalization layer of a transformer decoder.
--
-- The @T5@ and @ByT5@ styles resolve to a named 'LayerNorm' without bias,
-- the @Pegasus@ style to a named 'LayerNorm' with bias, both over the shape
-- @'[inputEmbedDim]@. The @BART@ and @MBART@ styles have no final layer
-- normalization and resolve to @()@.
type family
  TDFinalLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TDFinalLayerNormF 'T5 gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[inputEmbedDim]))
  TDFinalLayerNormF 'ByT5 gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[inputEmbedDim]))
  TDFinalLayerNormF 'BART _ _ _ _ = ()
  TDFinalLayerNormF 'MBART _ _ _ _ = ()
  TDFinalLayerNormF 'Pegasus gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))

-- | Specifies the final dropout layer of a transformer decoder.
--
-- Only the @T5@ and @ByT5@ styles have a final dropout: with @'WithDropout@
-- they resolve to 'Dropout'. The @BART@, @MBART@, and @Pegasus@ styles
-- resolve to @()@ regardless of @hasDropout@. With @'WithoutDropout@, every
-- style resolves to @()@.
type family
  TDFinalDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  TDFinalDropoutF 'T5 'WithDropout = Dropout
  TDFinalDropoutF 'ByT5 'WithDropout = Dropout
  -- These styles never have a final dropout, whatever @hasDropout@ is.
  TDFinalDropoutF 'BART _ = ()
  TDFinalDropoutF 'MBART _ = ()
  TDFinalDropoutF 'Pegasus _ = ()
  -- Catch-all for the dropout-free configuration; it overlaps the three
  -- equations above but agrees with them on the result.
  TDFinalDropoutF _ 'WithoutDropout = ()

-- | Specifies the parameters of a transformer in a decoder configuration.
--
-- - @style@: the style of the transformer stack, e.g. 'ST5', 'SByT5', etc.
-- - @numLayers@: the number of layers in the stack.
-- - @gradient@: whether to compute the gradient of the stack's parameters.
-- - @device@: the computational device on which the stack is allocated.
-- - @dataType@: the data type of the stack's parameters.
-- - @headDim@: the dimension of all transformer heads in the stack.
-- - @headEmbedDim@: the dimension of the transformer head embeddings.
-- - @embedDim@: the dimension of the transformer embeddings.
-- - @decoderInputEmbedDim@: the dimension of the decoder input embeddings.
-- - @encoderOutputEmbedDim@: the dimension of the encoder output embeddings.
-- - @ffnDim@: the dimension of the feed-forward network.
-- - @posEncDim@: the dimension of the positional encoding.
-- - @hasDropout@: whether the transformer is configured with dropout.
-- - @dropoutP@: the dropout rate.
-- - @eps@: the epsilon value for numerical stability of the layer normalization.
transformerDecoderSpec ::
  forall style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim posEncDim hasDropout.
  STransformerStyle style ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim decoderInputEmbedDim ->
  SDim encoderOutputEmbedDim ->
  SDim ffnDim ->
  SDim posEncDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (TransformerDecoderF style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim posEncDim hasDropout)
transformerDecoderSpec :: forall (style :: TransformerStyle) (numLayers :: Natural)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Natural))
       (headEmbedDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (decoderInputEmbedDim :: Dim (Name Symbol) (Size Natural))
       (encoderOutputEmbedDim :: Dim (Name Symbol) (Size Natural))
       (ffnDim :: Dim (Name Symbol) (Size Natural))
       (posEncDim :: Dim (Name Symbol) (Size Natural))
       (hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim decoderInputEmbedDim
-> SDim encoderOutputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (TransformerDecoderF
        style
        numLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        decoderInputEmbedDim
        encoderOutputEmbedDim
        ffnDim
        posEncDim
        hasDropout)
transformerDecoderSpec STransformerStyle style
style SNat numLayers
numLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim decoderInputEmbedDim
decoderInputEmbedDim SDim encoderOutputEmbedDim
encoderOutputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps =
  let posEncSpec :: STransformerStyle style
-> ModelSpec
     (TDPosEncF
        style gradient device dataType decoderInputEmbedDim posEncDim)
posEncSpec STransformerStyle style
ST5 = ()
      posEncSpec STransformerStyle style
SByT5 = ()
      posEncSpec STransformerStyle style
SBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"embed_positions." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  decoderInputEmbedDim
  'Nothing
posEncSpec'
      posEncSpec STransformerStyle style
SMBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"embed_positions." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  decoderInputEmbedDim
  'Nothing
posEncSpec'
      posEncSpec STransformerStyle style
SPegasus = forall model. Text -> model -> NamedModel model
NamedModel Text
"embed_positions." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  decoderInputEmbedDim
  'Nothing
posEncSpec'
      posEncSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
      posEncSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
      posEncSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      relPosEncSpec :: STransformerStyle style
-> ModelSpec
     (TDRelPosEncF style gradient device dataType headDim posEncDim)
relPosEncSpec STransformerStyle style
ST5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"block.0.layer.0.SelfAttention.relative_attention_bias." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  headDim
  'Nothing
relPosEncSpec'
      relPosEncSpec STransformerStyle style
SByT5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"block.0.layer.0.SelfAttention.relative_attention_bias." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  headDim
  'Nothing
relPosEncSpec'
      relPosEncSpec STransformerStyle style
SBART = ()
      relPosEncSpec STransformerStyle style
SMBART = ()
      relPosEncSpec STransformerStyle style
SPegasus = ()
      relPosEncSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
      relPosEncSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
      relPosEncSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      initialLayerNormSpec :: STransformerStyle style
-> ModelSpec
     (TDInitialLayerNormF
        style gradient device dataType decoderInputEmbedDim)
initialLayerNormSpec STransformerStyle style
ST5 = ()
      initialLayerNormSpec STransformerStyle style
SByT5 = ()
      initialLayerNormSpec STransformerStyle style
SBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"layernorm_embedding." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[decoderInputEmbedDim])
layerNormSpec' SHasBias 'WithBias
SWithBias
      initialLayerNormSpec STransformerStyle style
SMBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"layernorm_embedding." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[decoderInputEmbedDim])
layerNormSpec' SHasBias 'WithBias
SWithBias
      initialLayerNormSpec STransformerStyle style
SPegasus = ()
      initialLayerNormSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
      initialLayerNormSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
      initialLayerNormSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      initialDropoutSpec :: STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (TDInitialDropoutF style hasDropout)
initialDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SBART SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SBART SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SMBART SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SMBART SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SPegasus SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      initialDropoutSpec STransformerStyle style
SPegasus SHasDropout hasDropout
SWithoutDropout = ()
      initialDropoutSpec STransformerStyle style
SBERT SHasDropout hasDropout
_ = forall a. HasCallStack => a
undefined
      initialDropoutSpec STransformerStyle style
SRoBERTa SHasDropout hasDropout
_ = forall a. HasCallStack => a
undefined
      initialDropoutSpec STransformerStyle style
SGPT2 SHasDropout hasDropout
_ = forall a. HasCallStack => a
undefined
      stackSpec :: STransformerStyle style
-> NamedModel
     (GTransformerStack
        (VectorSpec
           numLayers
           (GTransformerBlock
              (NamedModel
                 (GSelfAttention
                    (SAInitialLayerNormF
                       style gradient device dataType decoderInputEmbedDim)
                    (NamedModel
                       (GMultiHeadAttention
                          headDim
                          headEmbedDim
                          embedDim
                          (QInProjF
                             style gradient device dataType decoderInputEmbedDim embedDim)
                          (KInProjF
                             style gradient device dataType decoderInputEmbedDim embedDim)
                          (VInProjF
                             style gradient device dataType decoderInputEmbedDim embedDim)
                          (OutProjF
                             style gradient device dataType embedDim decoderInputEmbedDim)
                          (DropoutF style hasDropout)))
                    (SADropoutF style hasDropout)
                    (SAFinalLayerNormF
                       style gradient device dataType decoderInputEmbedDim)))
              (NamedModel
                 (GCrossAttention
                    (CAInitialLayerNormF
                       style gradient device dataType decoderInputEmbedDim)
                    (NamedModel
                       (GMultiHeadAttention
                          headDim
                          headEmbedDim
                          embedDim
                          (QInProjF
                             style gradient device dataType decoderInputEmbedDim embedDim)
                          (KInProjF
                             style gradient device dataType encoderOutputEmbedDim embedDim)
                          (VInProjF
                             style gradient device dataType encoderOutputEmbedDim embedDim)
                          (OutProjF
                             style gradient device dataType embedDim decoderInputEmbedDim)
                          (DropoutF style hasDropout)))
                    (CADropoutF style hasDropout)
                    (CAFinalLayerNormF
                       style gradient device dataType decoderInputEmbedDim)))
              (NamedModel
                 (GTransformerFeedForwardNetwork
                    (FFNInputLayerNormF
                       style gradient device dataType decoderInputEmbedDim)
                    (FFNInputTransformationF
                       style gradient device dataType decoderInputEmbedDim ffnDim)
                    (FFNActivationF style)
                    (FFNActivationDropoutF style hasDropout)
                    (FFNOutputProjectionF
                       style gradient device dataType decoderInputEmbedDim ffnDim)
                    (FFNOutputDropoutF style hasDropout)
                    (FFNOutputLayerNormF
                       style gradient device dataType decoderInputEmbedDim))))))
stackSpec STransformerStyle style
ST5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"block." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GCrossAttention
                 (CAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (CADropoutF style hasDropout)
                 (CAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))))
stackSpec' STransformerStyle 'T5
ST5
      stackSpec STransformerStyle style
SByT5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"block." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GCrossAttention
                 (CAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (CADropoutF style hasDropout)
                 (CAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))))
stackSpec' STransformerStyle 'ByT5
SByT5
      stackSpec STransformerStyle style
SBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"layers." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GCrossAttention
                 (CAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (CADropoutF style hasDropout)
                 (CAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))))
stackSpec' STransformerStyle 'BART
SBART
      stackSpec STransformerStyle style
SMBART = forall model. Text -> model -> NamedModel model
NamedModel Text
"layers." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GCrossAttention
                 (CAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (CADropoutF style hasDropout)
                 (CAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))))
stackSpec' STransformerStyle 'MBART
SMBART
      stackSpec STransformerStyle style
SPegasus = forall model. Text -> model -> NamedModel model
NamedModel Text
"layers." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GCrossAttention
                 (CAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (CADropoutF style hasDropout)
                 (CAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))))
stackSpec' STransformerStyle 'Pegasus
SPegasus
      stackSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
      stackSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
      stackSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      finalLayerNormSpec :: STransformerStyle style
-> ModelSpec
     (TDFinalLayerNormF
        style gradient device dataType decoderInputEmbedDim)
finalLayerNormSpec STransformerStyle style
ST5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"final_layer_norm." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[decoderInputEmbedDim])
layerNormSpec' SHasBias 'WithoutBias
SWithoutBias
      finalLayerNormSpec STransformerStyle style
SByT5 = forall model. Text -> model -> NamedModel model
NamedModel Text
"final_layer_norm." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[decoderInputEmbedDim])
layerNormSpec' SHasBias 'WithoutBias
SWithoutBias
      finalLayerNormSpec STransformerStyle style
SBART = ()
      finalLayerNormSpec STransformerStyle style
SMBART = ()
      finalLayerNormSpec STransformerStyle style
SPegasus = forall model. Text -> model -> NamedModel model
NamedModel Text
"layer_norm." forall a b. (a -> b) -> a -> b
$ forall {hasBias :: HasBias}.
SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[decoderInputEmbedDim])
layerNormSpec' SHasBias 'WithBias
SWithBias
      finalLayerNormSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
      finalLayerNormSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
      finalLayerNormSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      finalDropoutSpec :: STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (TDFinalDropoutF style hasDropout)
finalDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      finalDropoutSpec STransformerStyle style
ST5 SHasDropout hasDropout
SWithoutDropout = ()
      finalDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      finalDropoutSpec STransformerStyle style
SByT5 SHasDropout hasDropout
SWithoutDropout = ()
      finalDropoutSpec STransformerStyle style
SBART SHasDropout hasDropout
_ = ()
      finalDropoutSpec STransformerStyle style
SMBART SHasDropout hasDropout
_ = ()
      finalDropoutSpec STransformerStyle style
SPegasus SHasDropout hasDropout
_ = ()
      finalDropoutSpec STransformerStyle style
SBERT SHasDropout hasDropout
_ = forall a. HasCallStack => a
undefined
      finalDropoutSpec STransformerStyle style
SRoBERTa SHasDropout hasDropout
_ = forall a. HasCallStack => a
undefined
      finalDropoutSpec STransformerStyle style
SGPT2 SHasDropout hasDropout
_ = forall a. HasCallStack => a
undefined
   in forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
posEnc
-> relPosEnc
-> initialLayerNorm
-> initialDropout
-> stack
-> finalLayerNorm
-> finalDropout
-> GTransformer
     posEnc
     relPosEnc
     initialLayerNorm
     initialDropout
     stack
     finalLayerNorm
     finalDropout
GTransformer
        (STransformerStyle style
-> ModelSpec
     (TDPosEncF
        style gradient device dataType decoderInputEmbedDim posEncDim)
posEncSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (TDRelPosEncF style gradient device dataType headDim posEncDim)
relPosEncSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (TDInitialLayerNormF
        style gradient device dataType decoderInputEmbedDim)
initialLayerNormSpec STransformerStyle style
style)
        (STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (TDInitialDropoutF style hasDropout)
initialDropoutSpec STransformerStyle style
style SHasDropout hasDropout
hasDropout)
        (STransformerStyle style
-> NamedModel
     (GTransformerStack
        (VectorSpec
           numLayers
           (GTransformerBlock
              (NamedModel
                 (GSelfAttention
                    (SAInitialLayerNormF
                       style gradient device dataType decoderInputEmbedDim)
                    (NamedModel
                       (GMultiHeadAttention
                          headDim
                          headEmbedDim
                          embedDim
                          (QInProjF
                             style gradient device dataType decoderInputEmbedDim embedDim)
                          (KInProjF
                             style gradient device dataType decoderInputEmbedDim embedDim)
                          (VInProjF
                             style gradient device dataType decoderInputEmbedDim embedDim)
                          (OutProjF
                             style gradient device dataType embedDim decoderInputEmbedDim)
                          (DropoutF style hasDropout)))
                    (SADropoutF style hasDropout)
                    (SAFinalLayerNormF
                       style gradient device dataType decoderInputEmbedDim)))
              (NamedModel
                 (GCrossAttention
                    (CAInitialLayerNormF
                       style gradient device dataType decoderInputEmbedDim)
                    (NamedModel
                       (GMultiHeadAttention
                          headDim
                          headEmbedDim
                          embedDim
                          (QInProjF
                             style gradient device dataType decoderInputEmbedDim embedDim)
                          (KInProjF
                             style gradient device dataType encoderOutputEmbedDim embedDim)
                          (VInProjF
                             style gradient device dataType encoderOutputEmbedDim embedDim)
                          (OutProjF
                             style gradient device dataType embedDim decoderInputEmbedDim)
                          (DropoutF style hasDropout)))
                    (CADropoutF style hasDropout)
                    (CAFinalLayerNormF
                       style gradient device dataType decoderInputEmbedDim)))
              (NamedModel
                 (GTransformerFeedForwardNetwork
                    (FFNInputLayerNormF
                       style gradient device dataType decoderInputEmbedDim)
                    (FFNInputTransformationF
                       style gradient device dataType decoderInputEmbedDim ffnDim)
                    (FFNActivationF style)
                    (FFNActivationDropoutF style hasDropout)
                    (FFNOutputProjectionF
                       style gradient device dataType decoderInputEmbedDim ffnDim)
                    (FFNOutputDropoutF style hasDropout)
                    (FFNOutputLayerNormF
                       style gradient device dataType decoderInputEmbedDim))))))
stackSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (TDFinalLayerNormF
        style gradient device dataType decoderInputEmbedDim)
finalLayerNormSpec STransformerStyle style
style)
        (STransformerStyle style
-> SHasDropout hasDropout
-> ModelSpec (TDFinalDropoutF style hasDropout)
finalDropoutSpec STransformerStyle style
style SHasDropout hasDropout
hasDropout)
  where
    stackSpec' :: _
    stackSpec' :: STransformerStyle style
-> GTransformerStack
     (VectorSpec
        numLayers
        (GTransformerBlock
           (NamedModel
              (GSelfAttention
                 (SAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (SADropoutF style hasDropout)
                 (SAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GCrossAttention
                 (CAInitialLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (NamedModel
                    (GMultiHeadAttention
                       headDim
                       headEmbedDim
                       embedDim
                       (QInProjF
                          style gradient device dataType decoderInputEmbedDim embedDim)
                       (KInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (VInProjF
                          style gradient device dataType encoderOutputEmbedDim embedDim)
                       (OutProjF
                          style gradient device dataType embedDim decoderInputEmbedDim)
                       (DropoutF style hasDropout)))
                 (CADropoutF style hasDropout)
                 (CAFinalLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))
           (NamedModel
              (GTransformerFeedForwardNetwork
                 (FFNInputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)
                 (FFNInputTransformationF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNActivationF style)
                 (FFNActivationDropoutF style hasDropout)
                 (FFNOutputProjectionF
                    style gradient device dataType decoderInputEmbedDim ffnDim)
                 (FFNOutputDropoutF style hasDropout)
                 (FFNOutputLayerNormF
                    style gradient device dataType decoderInputEmbedDim)))))
stackSpec' STransformerStyle style
style' = forall (style :: TransformerStyle) (numLayers :: Natural)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Natural))
       (headEmbedDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (queryEmbedDim :: Dim (Name Symbol) (Size Natural))
       (keyEmbedDim :: Dim (Name Symbol) (Size Natural))
       (ffnDim :: Dim (Name Symbol) (Size Natural))
       (hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim queryEmbedDim
-> SDim keyEmbedDim
-> SDim ffnDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (DecoderStackF
        style
        numLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        queryEmbedDim
        keyEmbedDim
        ffnDim
        hasDropout)
decoderStackSpec STransformerStyle style
style' SNat numLayers
numLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim decoderInputEmbedDim
decoderInputEmbedDim SDim encoderOutputEmbedDim
encoderOutputEmbedDim SDim ffnDim
ffnDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps
    layerNormSpec' :: _
    layerNormSpec' :: SHasBias hasBias
-> LayerNormSpec
     hasBias gradient device dataType ('Shape '[decoderInputEmbedDim])
layerNormSpec' SHasBias hasBias
hasBias = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (normalizedShape :: Shape [Dim (Name Symbol) (Size Natural)]).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SShape normalizedShape
-> Double
-> LayerNormSpec hasBias gradient device dataType normalizedShape
LayerNormSpec SHasBias hasBias
hasBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType (forall (dims :: [Dim (Name Symbol) (Size Natural)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim decoderInputEmbedDim
decoderInputEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil) Double
eps
    relPosEncSpec' :: EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  headDim
  'Nothing
relPosEncSpec' = forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (paddingIdx :: Maybe Natural).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim posEncDim
posEncDim SDim headDim
headDim forall a. SMaybe 'Nothing
SNothing
    posEncSpec' :: EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  posEncDim
  decoderInputEmbedDim
  'Nothing
posEncSpec' = forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (paddingIdx :: Maybe Natural).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim posEncDim
posEncDim SDim decoderInputEmbedDim
decoderInputEmbedDim forall a. SMaybe 'Nothing
SNothing

-- | 'HasInitialize' instance for 'GTransformer'.
--
-- Every one of the seven components (@posEnc@, @relPosEnc@,
-- @initialLayerNorm@, @initialDropout@, @stack@, @finalLayerNorm@, and
-- @finalDropout@) must itself have a 'HasInitialize' instance.
-- The constraints chain the generator device through the components in that
-- order: the output generator device of one component is the input generator
-- device of the next, starting at @generatorDevice@ and ending at
-- @generatorOutputDevice@.
-- The initialized model is again a 'GTransformer', built from the primed
-- (initialized) component types.
--
-- No methods are defined here, so the class's default implementation is used.
-- NOTE(review): presumably that default is 'Generic'-based; confirm in
-- "Torch.GraduallyTyped.NN.Class".
instance
  ( HasInitialize posEnc generatorDevice posEnc' generatorDevice0,
    HasInitialize relPosEnc generatorDevice0 relPosEnc' generatorDevice1,
    HasInitialize initialLayerNorm generatorDevice1 initialLayerNorm' generatorDevice2,
    HasInitialize initialDropout generatorDevice2 initialDropout' generatorDevice3,
    HasInitialize stack generatorDevice3 stack' generatorDevice4,
    HasInitialize finalLayerNorm generatorDevice4 finalLayerNorm' generatorDevice5,
    HasInitialize finalDropout generatorDevice5 finalDropout' generatorOutputDevice
  ) =>
  HasInitialize
    (GTransformer posEnc relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
    generatorDevice
    (GTransformer posEnc' relPosEnc' initialLayerNorm' initialDropout' stack' finalLayerNorm' finalDropout')
    generatorOutputDevice

-- | 'HasStateDict' instance for 'GTransformer'.
--
-- Requires a 'HasStateDict' instance for each of the seven components
-- (@posEnc@, @relPosEnc@, @initialLayerNorm@, @initialDropout@, @stack@,
-- @finalLayerNorm@, and @finalDropout@).
--
-- No methods are defined here, so the class's default implementation is used.
-- NOTE(review): presumably that default is 'Generic'-based and visits the
-- components field by field; confirm in "Torch.GraduallyTyped.NN.Class".
instance
  ( HasStateDict posEnc,
    HasStateDict relPosEnc,
    HasStateDict initialLayerNorm,
    HasStateDict initialDropout,
    HasStateDict stack,
    HasStateDict finalLayerNorm,
    HasStateDict finalDropout
  ) =>
  HasStateDict (GTransformer posEnc relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout)

-- | 'HasForward' instance for 'GTransformer' in an encoder configuration
-- with absolute positional encoding rather than relative positional encoding.
--
-- @
-- ┌───────┐  ┌─────┐  ┌───────────────┐
-- │ input │  │ pos │  │ attentionMask │
-- └───┬───┘  └─────┘  └───────┬───────┘
--     │         │             │
--     │         ▼             │
--     │      tPosEnc          │
--     │         │             │
--     └──►add◄──┘             │
--          │                  │
--          ▼                  │
-- (tInitialLayerNorm)         │
--          ▼                  ▼
--  (tInitialDropout)     unsqueeze
--          ▼                  │
--       tStack◄───────────────┘
--          ▼
--  (tFinalLayerNorm)
--          ▼
--   (tFinalDropout)
--          │
--          ▼
--     ┌────────┐
--     │ output │
--     └────────┘
-- @
instance
  ( HasForward
      posEnc
      (Tensor posGradient posLayout posDevice posDataType posShape)
      generatorDevice
      (Tensor posEncGradient posEncLayout posEncDevice posEncDataType posEncShape)
      generatorDevice0,
    HasForward
      initialLayerNorm
      ( Tensor
          (inputGradient <|> posEncGradient)
          (inputLayout <+> posEncLayout)
          (inputDevice <+> posEncDevice)
          (inputDataType <+> posEncDataType)
          (BroadcastShapesF inputShape posEncShape)
      )
      generatorDevice0
      tensor1
      generatorDevice1,
    Catch (BroadcastShapesF inputShape posEncShape),
    HasForward
      initialDropout
      tensor1
      generatorDevice1
      tensor2
      generatorDevice2,
    HasForward
      stack
      ( tensor2,
        Tensor
          attentionMaskGradient
          attentionMaskLayout
          attentionMaskDevice
          attentionMaskDataType
          (UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape)
      )
      generatorDevice2
      tensor3
      generatorDevice3,
    Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape),
    HasForward
      finalLayerNorm
      tensor3
      generatorDevice3
      tensor4
      generatorDevice4,
    HasForward
      finalDropout
      tensor4
      generatorDevice4
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformer posEnc () initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
    ( Tensor inputGradient inputLayout inputDevice inputDataType inputShape,
      Tensor posGradient posLayout posDevice posDataType posShape,
      Tensor attentionMaskGradient attentionMaskLayout attentionMaskDevice attentionMaskDataType attentionMaskShape
    )
    generatorDevice
    output
    generatorOutputDevice
  where
  forward :: forall (m :: * -> *).
MonadThrow m =>
GTransformer
  posEnc
  ()
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> (Tensor
      inputGradient inputLayout inputDevice inputDataType inputShape,
    Tensor posGradient posLayout posDevice posDataType posShape,
    Tensor
      attentionMaskGradient
      attentionMaskLayout
      attentionMaskDevice
      attentionMaskDataType
      attentionMaskShape)
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GTransformer {posEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
()
tFinalDropout :: finalDropout
tFinalLayerNorm :: finalLayerNorm
tStack :: stack
tInitialDropout :: initialDropout
tInitialLayerNorm :: initialLayerNorm
tRelPosEnc :: ()
tPosEnc :: posEnc
tFinalDropout :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> finalDropout
tFinalLayerNorm :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> finalLayerNorm
tStack :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> stack
tInitialDropout :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> initialDropout
tInitialLayerNorm :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> initialLayerNorm
tRelPosEnc :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> relPosEnc
tPosEnc :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> posEnc
..} (Tensor
  inputGradient inputLayout inputDevice inputDataType inputShape
input, Tensor posGradient posLayout posDevice posDataType posShape
pos, Tensor
  attentionMaskGradient
  attentionMaskLayout
  attentionMaskDevice
  attentionMaskDataType
  attentionMaskShape
attentionMask) =
    let attentionBias :: IxStateT
  m
  (Generator generatorDevice2)
  (Generator generatorDevice2)
  (Tensor
     attentionMaskGradient
     attentionMaskLayout
     attentionMaskDevice
     attentionMaskDataType
     (UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape))
attentionBias = forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall a b. (a -> b) -> a -> b
$ forall (selectDim :: SelectDim (By Symbol Natural))
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Natural)])
       (shape' :: Shape [Dim (Name Symbol) (Size Natural)]) (m :: * -> *).
(shape' ~ UnsqueezeF selectDim shape, Catch shape',
 SingI selectDim, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
unsqueeze @('SelectDim ('ByIndex 1)) Tensor
  attentionMaskGradient
  attentionMaskLayout
  attentionMaskDevice
  attentionMaskDataType
  attentionMaskShape
attentionMask
     in forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
          forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor posGradient posLayout posDevice posDataType posShape
pos
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward posEnc
tPosEnc
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Tensor
  inputGradient inputLayout inputDevice inputDataType inputShape
input forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Natural)])
       (gradient' :: Gradient RequiresGradient)
       (layout' :: Layout LayoutType)
       (device' :: Device (DeviceType Natural))
       (dataType' :: DataType DType)
       (shape' :: Shape [Dim (Name Symbol) (Size Natural)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Natural)])
       (m :: * -> *).
(MonadThrow m, shape'' ~ BroadcastShapesF shape shape',
 Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
        (gradient <|> gradient')
        (layout <+> layout')
        (device <+> device')
        (dataType <+> dataType')
        shape'')
`add`)
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward initialLayerNorm
tInitialLayerNorm
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward initialDropout
tInitialDropout
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \tensor2
input' ->
                     IxStateT
  m
  (Generator generatorDevice2)
  (Generator generatorDevice2)
  (Tensor
     attentionMaskGradient
     attentionMaskLayout
     attentionMaskDevice
     attentionMaskDataType
     (UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape))
attentionBias
                       forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \Tensor
  attentionMaskGradient
  attentionMaskLayout
  attentionMaskDevice
  attentionMaskDataType
  (UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape)
attentionBias' ->
                                forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall a b. (a -> b) -> a -> b
$ forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward stack
tStack (tensor2
input', Tensor
  attentionMaskGradient
  attentionMaskLayout
  attentionMaskDevice
  attentionMaskDataType
  (UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape)
attentionBias')
                            )
                 )
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward finalLayerNorm
tFinalLayerNorm
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward finalDropout
tFinalDropout

-- | 'HasForward' instance for 'GTransformer' in an encoder configuration
-- with relative positional encoding rather than absolute positional encoding.
--
-- The forward pass takes a triple of input, relative positions, and
-- attention mask:
--
-- * The input is passed through 'tInitialLayerNorm' and 'tInitialDropout'.
-- * The relative positions are passed through 'tRelPosEnc'; afterwards
--   dimensions 2 and 3, and then dimensions 1 and 2, of the result are
--   transposed.
-- * The attention mask is unsqueezed at dimension 1 and added to the
--   transposed relative position encoding, yielding the attention bias.
-- * The processed input and the attention bias are fed to 'tStack', whose
--   output is passed through 'tFinalLayerNorm' and 'tFinalDropout'.
--
-- The generator is threaded through the layers in this order:
-- 'tInitialLayerNorm', 'tInitialDropout', 'tRelPosEnc', 'tStack',
-- 'tFinalLayerNorm', 'tFinalDropout'.
--
-- @
--      ┌───────┐  ┌────────┐  ┌───────────────┐
--      │ input │  │ relPos │  │ attentionMask │
--      └───┬───┘  └───┬────┘  └───────┬───────┘
--          │          │               │
--          │          ▼               │
--          │     tRelPosEnc           │
--          │          ▼               │
--          │      transpose           │
--          │          ▼               ▼
--          │      transpose       unsqueeze
--          ▼          │               │
-- (tInitialLayerNorm) │               │
--          ▼          └─────►add◄─────┘
--  (tInitialDropout)          │
--          ▼                  │
--       tStack◄───────────────┘
--          ▼
--  (tFinalLayerNorm)
--          ▼
--   (tFinalDropout)
--          │
--          ▼
--     ┌────────┐
--     │ output │
--     └────────┘
-- @
instance
  ( HasForward
      initialLayerNorm
      (Tensor inputGradient inputLayout inputDevice inputDataType inputShape)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      initialDropout
      tensor0
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      relPosEnc
      (Tensor relPosGradient relPosLayout relPosDevice relPosDataType relPosShape)
      generatorDevice1
      (Tensor relPosEncGradient relPosEncLayout relPosEncDevice relPosEncDataType relPosEncShape)
      generatorDevice2,
    HasForward
      stack
      ( tensor1,
        Tensor
          (relPosEncGradient <|> attentionMaskGradient)
          (relPosEncLayout <+> attentionMaskLayout)
          (relPosEncDevice <+> attentionMaskDevice)
          (relPosEncDataType <+> attentionMaskDataType)
          (BroadcastShapesF doubleTransposedRelPosEncShape unsqueezedAttentionMaskShape)
      )
      generatorDevice2
      tensor3
      generatorDevice3,
    transposedRelPosEncShape ~ TransposeF ('SelectDim ('ByIndex 2)) ('SelectDim ('ByIndex 3)) relPosEncShape,
    Catch transposedRelPosEncShape,
    doubleTransposedRelPosEncShape ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) transposedRelPosEncShape,
    Catch doubleTransposedRelPosEncShape,
    unsqueezedAttentionMaskShape ~ UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape,
    Catch unsqueezedAttentionMaskShape,
    Catch (BroadcastShapesF doubleTransposedRelPosEncShape unsqueezedAttentionMaskShape),
    HasForward
      finalLayerNorm
      tensor3
      generatorDevice3
      tensor4
      generatorDevice4,
    HasForward
      finalDropout
      tensor4
      generatorDevice4
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformer () relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
    ( Tensor inputGradient inputLayout inputDevice inputDataType inputShape,
      Tensor relPosGradient relPosLayout relPosDevice relPosDataType relPosShape,
      Tensor attentionMaskGradient attentionMaskLayout attentionMaskDevice attentionMaskDataType attentionMaskShape
    )
    generatorDevice
    output
    generatorOutputDevice
  where
  forward GTransformer {..} (input, relPos, attentionMask) =
    let -- Encode the relative positions, then swap dimensions 2 and 3
        -- followed by dimensions 1 and 2 of the encoding.
        relPosBias =
          ireturn relPos
            >>>= IxStateT . forward tRelPosEnc
            >>>= ilift . sTranspose (SSelectDim $ SByIndex @2) (SSelectDim $ SByIndex @3)
            >>>= ilift . sTranspose (SSelectDim $ SByIndex @1) (SSelectDim $ SByIndex @2)
        -- Unsqueeze the attention mask at dimension 1 and add it to the
        -- relative position bias (the bias is the left operand of 'add').
        attentionBias =
          relPosBias
            >>>= ilift . (sUnsqueeze (SSelectDim $ SByIndex @1) attentionMask >>=) . add
     in runIxStateT $
          ireturn input
            >>>= IxStateT . forward tInitialLayerNorm
            >>>= IxStateT . forward tInitialDropout
            -- 'attentionBias' is only run here, i.e. after the initial
            -- dropout has consumed the generator, so that the generator
            -- device indices line up with the instance context.
            >>>= ( \input' ->
                     attentionBias
                       >>>= ( \attentionBias' ->
                                IxStateT $ forward tStack (input', attentionBias')
                            )
                 )
            >>>= IxStateT . forward tFinalLayerNorm
            >>>= IxStateT . forward tFinalDropout

-- | 'HasForward' instance for 'GTransformer' in a decoder configuration
-- with absolute positional encoding rather than relative positional encoding.
--
-- @
-- ┌──────────────┐  ┌────────────┐  ┌───────────────┐  ┌──────────────────────┐  ┌────────────────────┐
-- │ decoderInput │  │ decoderPos │  │ encoderOutput │  │ decoderAttentionMask │  │ crossAttentionMask │
-- └──────┬───────┘  └──────┬─────┘  └───────┬───────┘  └──────────┬───────────┘  └──────────┬─────────┘
--        │                 │                │                     │                         │
--        │                 ▼                │                     │                         │
--        │              tPosEnc             │                     │                         │
--        │                 │                │                     │                         │
--        └──────►add◄──────┘                │                     │                         │
--                 │                         │                     │                         │
--                 ▼                         │                     │                         │
--        (tInitialLayerNorm)                │                     │                         │
--                 ▼                         │                     ▼                         ▼
--         (tInitialDropout)                 │                 unsqueeze                 unsqueeze
--                 ▼                         │                     │                         │
--              tStack◄──────────────────────┘◄────────────────────┘◄────────────────────────┘
--                 ▼
--         (tFinalLayerNorm)
--                 ▼
--          (tFinalDropout)
--                 │
--                 ▼
--            ┌────────┐
--            │ output │
--            └────────┘
-- @
instance
  ( HasForward
      posEnc
      (Tensor decoderPosGradient decoderPosLayout decoderPosDevice decoderPosDataType decoderPosShape)
      generatorDevice
      (Tensor decoderPosEncGradient decoderPosEncLayout decoderPosEncDevice decoderPosEncDataType decoderPosEncShape)
      generatorDevice0,
    HasForward
      initialLayerNorm
      ( Tensor
          (decoderInputGradient <|> decoderPosEncGradient)
          (decoderInputLayout <+> decoderPosEncLayout)
          (decoderInputDevice <+> decoderPosEncDevice)
          (decoderInputDataType <+> decoderPosEncDataType)
          (BroadcastShapesF decoderInputShape decoderPosEncShape)
      )
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      initialDropout
      tensor1
      generatorDevice1
      tensor2
      generatorDevice2,
    HasForward
      stack
      ( tensor2,
        Tensor encoderOutputGradient encoderOutputLayout encoderOutputDevice encoderOutputDataType encoderOutputShape,
        Tensor
          decoderAttentionMaskGradient
          decoderAttentionMaskLayout
          decoderAttentionMaskDevice
          decoderAttentionMaskDataType
          (UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape),
        Tensor
          crossAttentionMaskGradient
          crossAttentionMaskLayout
          crossAttentionMaskDevice
          crossAttentionMaskDataType
          (UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape)
      )
      generatorDevice2
      tensor3
      generatorDevice3,
    Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape),
    Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape),
    Catch (BroadcastShapesF decoderInputShape decoderPosEncShape),
    HasForward
      finalLayerNorm
      tensor3
      generatorDevice3
      tensor4
      generatorDevice4,
    HasForward
      finalDropout
      tensor4
      generatorDevice4
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformer posEnc () initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
    ( Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape,
      Tensor encoderOutputGradient encoderOutputLayout encoderOutputDevice encoderOutputDataType encoderOutputShape,
      Tensor decoderPosGradient decoderPosLayout decoderPosDevice decoderPosDataType decoderPosShape,
      Tensor decoderAttentionMaskGradient decoderAttentionMaskLayout decoderAttentionMaskDevice decoderAttentionMaskDataType decoderAttentionMaskShape,
      Tensor crossAttentionMaskGradient crossAttentionMaskLayout crossAttentionMaskDevice crossAttentionMaskDataType crossAttentionMaskShape
    )
    generatorDevice
    output
    generatorOutputDevice
  where
  forward :: forall (m :: * -> *).
MonadThrow m =>
GTransformer
  posEnc
  ()
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> (Tensor
      decoderInputGradient
      decoderInputLayout
      decoderInputDevice
      decoderInputDataType
      decoderInputShape,
    Tensor
      encoderOutputGradient
      encoderOutputLayout
      encoderOutputDevice
      encoderOutputDataType
      encoderOutputShape,
    Tensor
      decoderPosGradient
      decoderPosLayout
      decoderPosDevice
      decoderPosDataType
      decoderPosShape,
    Tensor
      decoderAttentionMaskGradient
      decoderAttentionMaskLayout
      decoderAttentionMaskDevice
      decoderAttentionMaskDataType
      decoderAttentionMaskShape,
    Tensor
      crossAttentionMaskGradient
      crossAttentionMaskLayout
      crossAttentionMaskDevice
      crossAttentionMaskDataType
      crossAttentionMaskShape)
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GTransformer {posEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
()
tFinalDropout :: finalDropout
tFinalLayerNorm :: finalLayerNorm
tStack :: stack
tInitialDropout :: initialDropout
tInitialLayerNorm :: initialLayerNorm
tRelPosEnc :: ()
tPosEnc :: posEnc
tFinalDropout :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> finalDropout
tFinalLayerNorm :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> finalLayerNorm
tStack :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> stack
tInitialDropout :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> initialDropout
tInitialLayerNorm :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> initialLayerNorm
tRelPosEnc :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> relPosEnc
tPosEnc :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
       finalLayerNorm finalDropout.
GTransformer
  posEnc
  relPosEnc
  initialLayerNorm
  initialDropout
  stack
  finalLayerNorm
  finalDropout
-> posEnc
..} (Tensor
  decoderInputGradient
  decoderInputLayout
  decoderInputDevice
  decoderInputDataType
  decoderInputShape
decoderInput, Tensor
  encoderOutputGradient
  encoderOutputLayout
  encoderOutputDevice
  encoderOutputDataType
  encoderOutputShape
encoderOutput, Tensor
  decoderPosGradient
  decoderPosLayout
  decoderPosDevice
  decoderPosDataType
  decoderPosShape
decoderPos, Tensor
  decoderAttentionMaskGradient
  decoderAttentionMaskLayout
  decoderAttentionMaskDevice
  decoderAttentionMaskDataType
  decoderAttentionMaskShape
decoderAttentionMask, Tensor
  crossAttentionMaskGradient
  crossAttentionMaskLayout
  crossAttentionMaskDevice
  crossAttentionMaskDataType
  crossAttentionMaskShape
crossAttentionMask) =
    let decoderAttentionBias :: IxStateT
  m
  (Generator generatorDevice2)
  (Generator generatorDevice2)
  (Tensor
     decoderAttentionMaskGradient
     decoderAttentionMaskLayout
     decoderAttentionMaskDevice
     decoderAttentionMaskDataType
     (UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape))
decoderAttentionBias = forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall a b. (a -> b) -> a -> b
$ forall (selectDim :: SelectDim (By Symbol Natural))
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Natural)])
       (shape' :: Shape [Dim (Name Symbol) (Size Natural)]) (m :: * -> *).
(shape' ~ UnsqueezeF selectDim shape, Catch shape',
 SingI selectDim, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
unsqueeze @('SelectDim ('ByIndex 1)) Tensor
  decoderAttentionMaskGradient
  decoderAttentionMaskLayout
  decoderAttentionMaskDevice
  decoderAttentionMaskDataType
  decoderAttentionMaskShape
decoderAttentionMask
        crossAttentionBias :: IxStateT
  m
  (Generator generatorDevice2)
  (Generator generatorDevice2)
  (Tensor
     crossAttentionMaskGradient
     crossAttentionMaskLayout
     crossAttentionMaskDevice
     crossAttentionMaskDataType
     (UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape))
crossAttentionBias = forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall a b. (a -> b) -> a -> b
$ forall (selectDim :: SelectDim (By Symbol Natural))
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Natural)])
       (shape' :: Shape [Dim (Name Symbol) (Size Natural)]) (m :: * -> *).
(shape' ~ UnsqueezeF selectDim shape, Catch shape',
 SingI selectDim, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
unsqueeze @('SelectDim ('ByIndex 1)) Tensor
  crossAttentionMaskGradient
  crossAttentionMaskLayout
  crossAttentionMaskDevice
  crossAttentionMaskDataType
  crossAttentionMaskShape
crossAttentionMask
     in forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
          forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor
  decoderPosGradient
  decoderPosLayout
  decoderPosDevice
  decoderPosDataType
  decoderPosShape
decoderPos
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward posEnc
tPosEnc
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Tensor
  decoderInputGradient
  decoderInputLayout
  decoderInputDevice
  decoderInputDataType
  decoderInputShape
decoderInput forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Natural)])
       (gradient' :: Gradient RequiresGradient)
       (layout' :: Layout LayoutType)
       (device' :: Device (DeviceType Natural))
       (dataType' :: DataType DType)
       (shape' :: Shape [Dim (Name Symbol) (Size Natural)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Natural)])
       (m :: * -> *).
(MonadThrow m, shape'' ~ BroadcastShapesF shape shape',
 Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
        (gradient <|> gradient')
        (layout <+> layout')
        (device <+> device')
        (dataType <+> dataType')
        shape'')
`add`)
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward initialLayerNorm
tInitialLayerNorm
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward initialDropout
tInitialDropout
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \tensor2
decoderInput' ->
                     IxStateT
  m
  (Generator generatorDevice2)
  (Generator generatorDevice2)
  (Tensor
     decoderAttentionMaskGradient
     decoderAttentionMaskLayout
     decoderAttentionMaskDevice
     decoderAttentionMaskDataType
     (UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape))
decoderAttentionBias
                       forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \Tensor
  decoderAttentionMaskGradient
  decoderAttentionMaskLayout
  decoderAttentionMaskDevice
  decoderAttentionMaskDataType
  (UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape)
decoderAttentionBias' ->
                                IxStateT
  m
  (Generator generatorDevice2)
  (Generator generatorDevice2)
  (Tensor
     crossAttentionMaskGradient
     crossAttentionMaskLayout
     crossAttentionMaskDevice
     crossAttentionMaskDataType
     (UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape))
crossAttentionBias
                                  forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \Tensor
  crossAttentionMaskGradient
  crossAttentionMaskLayout
  crossAttentionMaskDevice
  crossAttentionMaskDataType
  (UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape)
crossAttentionBias' ->
                                           forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall a b. (a -> b) -> a -> b
$
                                             forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward
                                               stack
tStack
                                               ( tensor2
decoderInput',
                                                 Tensor
  encoderOutputGradient
  encoderOutputLayout
  encoderOutputDevice
  encoderOutputDataType
  encoderOutputShape
encoderOutput,
                                                 Tensor
  decoderAttentionMaskGradient
  decoderAttentionMaskLayout
  decoderAttentionMaskDevice
  decoderAttentionMaskDataType
  (UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape)
decoderAttentionBias',
                                                 Tensor
  crossAttentionMaskGradient
  crossAttentionMaskLayout
  crossAttentionMaskDevice
  crossAttentionMaskDataType
  (UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape)
crossAttentionBias'
                                               )
                                       )
                            )
                 )
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward finalLayerNorm
tFinalLayerNorm
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
       output (generatorOutputDevice :: Device (DeviceType Natural))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward finalDropout
tFinalDropout

-- | 'HasForward' instance for 'GTransformer' in a decoder configuration
-- with relative positional encoding rather than absolute positional encoding.
--
-- @
--   ┌──────────────┐  ┌───────────────┐  ┌───────────────┐  ┌──────────────────────┐  ┌────────────────────┐
--   │ decoderInput │  │ encoderOutput │  │ decoderRelPos │  │ decoderAttentionMask │  │ crossAttentionMask │
--   └──────┬───────┘  └───────┬───────┘  └───────┬───────┘  └──────────┬───────────┘  └─────────┬──────────┘
--          │                  │                  │                     │                        │
--          │                  │                  ▼                     │                        │
--          │                  │              tRelPosEnc                │                        │
--          │                  │                  ▼                     │                        │
--          │                  │              transpose                 │                        │
--          │                  │                  ▼                     ▼                        ▼
--          │                  │              transpose             unsqueeze                unsqueeze
--          ▼                  │                  │                     │                        │
-- (tInitialLayerNorm)         │                  │                     │                        │
--          ▼                  │                  └────────►add◄────────┘                        │
--  (tInitialDropout)          │                             │                                   │
--          ▼                  │                             │                                   │
--       tStack◄───────────────┘◄────────────────────────────┘◄──────────────────────────────────┘
--          ▼
--  (tFinalLayerNorm)
--          ▼
--   (tFinalDropout)
--          │
--          ▼
--     ┌────────┐
--     │ output │
--     └────────┘
-- @
instance
  ( -- The generator device indices below follow the order in which 'forward'
    -- runs the sub-layers: initial layer norm, initial dropout, relative
    -- positional encoding, stack, final layer norm, final dropout.
    HasForward
      initialLayerNorm
      (Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      initialDropout
      tensor0
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      relPosEnc
      (Tensor decoderRelPosGradient decoderRelPosLayout decoderRelPosDevice decoderRelPosDataType decoderRelPosShape)
      generatorDevice1
      (Tensor decoderRelPosEncGradient decoderRelPosEncLayout decoderRelPosEncDevice decoderRelPosEncDataType decoderRelPosEncShape)
      generatorDevice2,
    HasForward
      stack
      ( tensor1,
        Tensor encoderOutputGradient encoderOutputLayout encoderOutputDevice encoderOutputDataType encoderOutputShape,
        Tensor
          (decoderRelPosEncGradient <|> decoderAttentionMaskGradient)
          (decoderRelPosEncLayout <+> decoderAttentionMaskLayout)
          (decoderRelPosEncDevice <+> decoderAttentionMaskDevice)
          (decoderRelPosEncDataType <+> decoderAttentionMaskDataType)
          (BroadcastShapesF doubleTransposedDecoderRelPosEncShape unsqueezedDecoderAttentionMaskShape),
        Tensor
          crossAttentionMaskGradient
          crossAttentionMaskLayout
          crossAttentionMaskDevice
          crossAttentionMaskDataType
          unsqueezedCrossAttentionMaskShape
      )
      generatorDevice2
      tensor3
      generatorDevice3,
    -- Shape bookkeeping for the two transposes, the two unsqueezes, and the
    -- broadcasting addition performed in 'forward'.
    transposedDecoderRelPosEncShape ~ TransposeF ('SelectDim ('ByIndex 2)) ('SelectDim ('ByIndex 3)) decoderRelPosEncShape,
    Catch transposedDecoderRelPosEncShape,
    doubleTransposedDecoderRelPosEncShape ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) transposedDecoderRelPosEncShape,
    Catch doubleTransposedDecoderRelPosEncShape,
    unsqueezedDecoderAttentionMaskShape ~ UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape,
    Catch unsqueezedDecoderAttentionMaskShape,
    unsqueezedCrossAttentionMaskShape ~ UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape,
    Catch unsqueezedCrossAttentionMaskShape,
    Catch (BroadcastShapesF doubleTransposedDecoderRelPosEncShape unsqueezedDecoderAttentionMaskShape),
    HasForward
      finalLayerNorm
      tensor3
      generatorDevice3
      tensor4
      generatorDevice4,
    HasForward
      finalDropout
      tensor4
      generatorDevice4
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformer () relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
    ( Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape,
      Tensor encoderOutputGradient encoderOutputLayout encoderOutputDevice encoderOutputDataType encoderOutputShape,
      Tensor decoderRelPosGradient decoderRelPosLayout decoderRelPosDevice decoderRelPosDataType decoderRelPosShape,
      Tensor decoderAttentionMaskGradient decoderAttentionMaskLayout decoderAttentionMaskDevice decoderAttentionMaskDataType decoderAttentionMaskShape,
      Tensor crossAttentionMaskGradient crossAttentionMaskLayout crossAttentionMaskDevice crossAttentionMaskDataType crossAttentionMaskShape
    )
    generatorDevice
    output
    generatorOutputDevice
  where
  -- Input tuple, in order:
  -- (decoderInput, encoderOutput, decoderRelPos, decoderAttentionMask, crossAttentionMask).
  --
  -- The record wildcard brings the sub-layer fields ('tRelPosEnc',
  -- 'tInitialLayerNorm', 'tInitialDropout', 'tStack', 'tFinalLayerNorm',
  -- 'tFinalDropout') into scope; 'tPosEnc' is '()' in this configuration and
  -- is not used.
  --
  -- Every sub-layer is run inside 'IxStateT' so that the random generator is
  -- threaded from one 'forward' call to the next; 'runIxStateT' turns the
  -- finished pipeline back into a function of the incoming generator.
  forward GTransformer {..} (decoderInput, encoderOutput, decoderRelPos, decoderAttentionMask, crossAttentionMask) =
    let -- Encode the relative positions with 'tRelPosEnc', then swap
        -- dimensions 2 and 3 followed by dimensions 1 and 2.
        decoderRelPosBias =
          ireturn decoderRelPos
            >>>= IxStateT . forward tRelPosEnc
            >>>= ilift . transpose @('SelectDim ('ByIndex 2)) @('SelectDim ('ByIndex 3))
            >>>= ilift . transpose @('SelectDim ('ByIndex 1)) @('SelectDim ('ByIndex 2))
        -- Insert a dimension at index 1 into the decoder attention mask and
        -- add it to the relative position bias. The bias is the first
        -- argument of 'add' and the unsqueezed mask the second, matching the
        -- 'BroadcastShapesF' argument order in the instance context.
        decoderAttentionBias =
          decoderRelPosBias
            >>>= \relPosBias ->
              ilift $
                sUnsqueeze (SSelectDim $ SByIndex @1) decoderAttentionMask
                  >>= add relPosBias
        -- Insert a dimension at index 1 into the cross attention mask.
        -- 'ilift' leaves the generator state untouched.
        crossAttentionBias =
          ilift $ unsqueeze @('SelectDim ('ByIndex 1)) crossAttentionMask
     in runIxStateT $
          ireturn decoderInput
            >>>= IxStateT . forward tInitialLayerNorm
            >>>= IxStateT . forward tInitialDropout
            -- The biases are sequenced only after the decoder input has been
            -- normalized and dropped out, so 'tRelPosEnc' receives the
            -- generator left behind by 'tInitialDropout'.
            >>>= ( \decoderInput' ->
                     decoderAttentionBias
                       >>>= ( \decoderAttentionBias' ->
                                crossAttentionBias
                                  >>>= ( \crossAttentionBias' ->
                                           IxStateT $
                                             forward
                                               tStack
                                               ( decoderInput',
                                                 encoderOutput,
                                                 decoderAttentionBias',
                                                 crossAttentionBias'
                                               )
                                       )
                            )
                 )
            >>>= IxStateT . forward tFinalLayerNorm
            >>>= IxStateT . forward tFinalDropout