{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}

module Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder where

import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import Data.Singletons (SingKind (fromSing))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Index.Type (Index (NegativeIndex), SIndex (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Functional.NonLinearActivation (SoftmaxF, logSoftmax)
import Torch.GraduallyTyped.NN.Sparse (Embedding (..), EmbeddingSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GLMHead (GLMHeadF, lmHeadSpec)
import Torch.GraduallyTyped.NN.Transformer.GTransformer (TransformerDecoderF, TransformerEncoderF, transformerDecoderSpec, transformerEncoderSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead (..), STransformerStyle (..), ShiftRight, TransformerHead (WithLMHead, WithoutHead), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.Prelude.Maybe (SMaybe (SNothing))
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim, SSelectDim (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Indexing (IndexDims, IndexType (..), Indices (..), SIndexType (..), SIndices (..), (!))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (GatherDimF, SqueezeDimF, UnsqueezeF, sGatherDim, sSqueezeDim, sUnsqueeze)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mulScalar)
import Torch.GraduallyTyped.Tensor.MathOperations.Reduction (MeanAllCheckF, meanAll)
import Torch.GraduallyTyped.Tensor.Type (Tensor ())
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Prelude hiding (head)

-- | Data type that is used to represent whether the encoder-decoder transformer model has a scaled embedding.
--
-- The constructors are ordered by the derived 'Ord' instance in declaration order,
-- i.e. 'EncoderDecoderTransformerWithEmbedScaling' compares less than
-- 'EncoderDecoderTransformerWithoutEmbedScaling'.
data EncoderDecoderTransformerHasEmbedScaling
  = -- | embedding scaling is enabled
    EncoderDecoderTransformerWithEmbedScaling
  | -- | embedding scaling is disabled
    EncoderDecoderTransformerWithoutEmbedScaling
  deriving stock (Eq, Ord, Show, Generic)

-- | The specification of the embedding-scaling flag is the flag itself.
type instance ModelSpec EncoderDecoderTransformerHasEmbedScaling = EncoderDecoderTransformerHasEmbedScaling

-- | Initializing the embedding-scaling flag is a no-op:
-- the specification is returned unchanged as the model,
-- and the generator is passed through untouched.
instance HasInitialize EncoderDecoderTransformerHasEmbedScaling generatorDevice EncoderDecoderTransformerHasEmbedScaling generatorDevice where
  initialize hasEmbedScaling g = pure (hasEmbedScaling, g)

-- | The embedding-scaling flag does not read from or write to the state dictionary.
instance HasStateDict EncoderDecoderTransformerHasEmbedScaling where
  -- The flag is taken directly from the specification; the key is ignored.
  fromStateDict hasEmbedScaling _ = pure hasEmbedScaling

  -- Nothing is written for the flag; both the key and the value are ignored.
  toStateDict _ _ = pure ()

-- | Generic encoder-decoder transformer model.
-- This is a model that can be used to encode and decode sequences of variable length.
--
-- - @inputEmbedDim@: the dimension of the input embedding.
-- - @encoder@: a transformer encoder.
-- - @decoder@: a transformer decoder.
-- - @sharedEmbedding@: a shared embedding layer.
-- - @head@: a head layer for the output.
data
  GEncoderDecoderTransformer
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (encoder :: Type)
    (decoder :: Type)
    (sharedEmbedding :: Type)
    (head :: Type)
  where
  GEncoderDecoderTransformer ::
    forall inputEmbedDim encoder decoder sharedEmbedding head.
    { -- | input embedding dim for scaling
      edtInputEmbedDim :: SDim inputEmbedDim,
      -- | encoder
      edtEncoder :: encoder,
      -- | decoder
      edtDecoder :: decoder,
      -- | embedding shared between encoder and decoder
      edtSharedEmbedding :: sharedEmbedding,
      -- | transformer head
      edtHead :: head,
      -- | whether or not embedding scaling is used
      edtEmbedScaling :: EncoderDecoderTransformerHasEmbedScaling
    } ->
    GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head
  deriving stock (Show, Generic)

-- | The specification of a 'GEncoderDecoderTransformer' is the same record
-- in which the encoder, decoder, shared embedding, and head are each replaced
-- by their own 'ModelSpec'. The @inputEmbedDim@ parameter is kept unchanged.
type instance
  ModelSpec (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head) =
    GEncoderDecoderTransformer inputEmbedDim (ModelSpec encoder) (ModelSpec decoder) (ModelSpec sharedEmbedding) (ModelSpec head)

-- | Specifies the type of an encoder-decoder transformer model.
--
-- The result is a 'GEncoderDecoderTransformer' indexed by @inputEmbedDim@ whose
-- components are computed by the sibling type families:
--
-- - the encoder by 'EDTEncoderF' (using @numEncoderLayers@),
-- - the decoder by 'EDTDecoderF' (using @numDecoderLayers@),
-- - the shared embedding by 'EDTSharedEmbeddingF', and
-- - the head by 'EDTHeadF' (selected by @transformerHead@).
type family
  GEncoderDecoderTransformerF
    (style :: TransformerStyle)
    (transformerHead :: TransformerHead)
    (numEncoderLayers :: Nat)
    (numDecoderLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  GEncoderDecoderTransformerF style transformerHead numEncoderLayers numDecoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim hasDropout =
    GEncoderDecoderTransformer
      inputEmbedDim
      (EDTEncoderF style numEncoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout)
      (EDTDecoderF style numDecoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout)
      (EDTSharedEmbeddingF style gradient device dataType inputEmbedDim vocabDim)
      (EDTHeadF style transformerHead gradient device dataType inputEmbedDim vocabDim)

-- | Specifies the encoder of the encoder-decoder transformer model.
--
-- The encoder is the 'TransformerEncoderF' for the given style and dimensions,
-- wrapped in a 'NamedModel'.
type family
  EDTEncoderF
    (style :: TransformerStyle)
    (numEncoderLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  EDTEncoderF style numEncoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout =
    NamedModel (TransformerEncoderF style numEncoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout)

-- | Specifies the decoder of the encoder-decoder transformer model.
--
-- The decoder is the 'TransformerDecoderF' for the given style and dimensions,
-- wrapped in a 'NamedModel'.
--
-- Note that @inputEmbedDim@ is passed to 'TransformerDecoderF' twice in a row.
-- NOTE(review): presumably once as the decoder's own input embedding dimension
-- and once as the dimension of the encoder output it attends to —
-- confirm against the parameter list of 'TransformerDecoderF'.
type family
  EDTDecoderF
    (style :: TransformerStyle)
    (numDecoderLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  EDTDecoderF style numDecoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout =
    NamedModel (TransformerDecoderF style numDecoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim inputEmbedDim ffnDim posEncDim hasDropout)

-- | Specifies the shared embedding layer of the encoder-decoder transformer model.
--
-- The result does not depend on the transformer style (the @style@ argument is ignored).
-- It is always a dense-layout 'Embedding' with @vocabDim@ and @inputEmbedDim@ as its
-- dimensions, wrapped in a 'NamedModel'.
-- NOTE(review): the trailing @'Nothing@ presumably means "no padding index" —
-- confirm against the definition of 'Embedding'.
type family
  EDTSharedEmbeddingF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  EDTSharedEmbeddingF _ gradient device dataType inputEmbedDim vocabDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType vocabDim inputEmbedDim 'Nothing)

-- | Specifies the head of the encoder-decoder transformer model.
--
-- - For @'WithoutHead@, the head is the unit type @()@.
-- - For @'WithLMHead@, the head is the 'GLMHeadF' for the given style and dimensions,
--   wrapped in a 'NamedModel'.
type family
  EDTHeadF
    (style :: TransformerStyle)
    (transformerHead :: TransformerHead)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  EDTHeadF style 'WithoutHead gradient device dataType inputEmbedDim vocabDim =
    ()
  EDTHeadF style 'WithLMHead gradient device dataType inputEmbedDim vocabDim =
    NamedModel (GLMHeadF style gradient device dataType inputEmbedDim vocabDim)

-- | Specifies the parameters of an encoder-decoder transformer model.
encoderDecoderTransformerSpec ::
  forall style transformerHead numEncoderLayers numDecoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim hasDropout.
  -- | the style of the encoder-decoder transformer model, e.g. 'ST5', 'SBART', etc.
  STransformerStyle style ->
  -- | the head of the encoder-decoder transformer model.
  STransformerHead transformerHead ->
  -- | the number of encoder layers of the encoder-decoder transformer model.
  SNat numEncoderLayers ->
  -- | the number of decoder layers of the encoder-decoder transformer model.
  SNat numDecoderLayers ->
  -- | whether or not to compute the gradient of the model parameters
  SGradient gradient ->
  -- | the computational device on which the model is allocated.
  SDevice device ->
  -- | the data type of the model parameters.
  SDataType dataType ->
  -- | the dimension of all transformer heads in the encoder-decoder transformer model.
  SDim headDim ->
  -- | the dimension of the transformer head embeddings.
  SDim headEmbedDim ->
  -- | the dimension of the transformer embeddings.
  SDim embedDim ->
  -- | the dimension of the input embeddings for both the encoder and the decoder.
  SDim inputEmbedDim ->
  -- | the dimension of the feed-forward network.
  SDim ffnDim ->
  -- | the dimension of the positional embeddings.
  SDim posEncDim ->
  -- | the dimension of the vocabulary.
  SDim vocabDim ->
  -- | whether or not to use dropout.
  SHasDropout hasDropout ->
  -- | the dropout rate.
  Double ->
  -- | the epsilon value for numerical stability of the layer normalization.
  Double ->
  -- | the parameter specification of an encoder-decoder transformer model.
  ModelSpec (GEncoderDecoderTransformerF style transformerHead numEncoderLayers numDecoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim hasDropout)
encoderDecoderTransformerSpec :: forall (style :: TransformerStyle)
       (transformerHead :: TransformerHead) (numEncoderLayers :: Nat)
       (numDecoderLayers :: Nat) (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
       (ffnDim :: Dim (Name Symbol) (Size Nat))
       (posEncDim :: Dim (Name Symbol) (Size Nat))
       (vocabDim :: Dim (Name Symbol) (Size Nat))
       (hasDropout :: HasDropout).
STransformerStyle style
-> STransformerHead transformerHead
-> SNat numEncoderLayers
-> SNat numDecoderLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SDim vocabDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (GEncoderDecoderTransformerF
        style
        transformerHead
        numEncoderLayers
        numDecoderLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        inputEmbedDim
        ffnDim
        posEncDim
        vocabDim
        hasDropout)
encoderDecoderTransformerSpec STransformerStyle style
style STransformerHead transformerHead
transformerHead SNat numEncoderLayers
numEncoderLayers SNat numDecoderLayers
numDecoderLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SDim vocabDim
vocabDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps =
  let encoderSpec :: STransformerStyle style
-> NamedModel
     (GTransformer
        (ModelSpec
           (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
        (ModelSpec
           (TERelPosEncF style gradient device dataType headDim posEncDim))
        (ModelSpec
           (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TEInitialDropoutF style hasDropout))
        (NamedModel
           (GTransformerStack
              (VectorSpec
                 numEncoderLayers
                 (GTransformerBlock
                    (NamedModel
                       (GSelfAttention
                          (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                          (NamedModel
                             (GMultiHeadAttention
                                headDim
                                headEmbedDim
                                embedDim
                                (QInProjF style gradient device dataType inputEmbedDim embedDim)
                                (KInProjF style gradient device dataType inputEmbedDim embedDim)
                                (VInProjF style gradient device dataType inputEmbedDim embedDim)
                                (OutProjF style gradient device dataType embedDim inputEmbedDim)
                                (DropoutF style hasDropout)))
                          (SADropoutF style hasDropout)
                          (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                    ()
                    (NamedModel
                       (GTransformerFeedForwardNetwork
                          (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                          (FFNInputTransformationF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNActivationF style)
                          (FFNActivationDropoutF style hasDropout)
                          (FFNOutputProjectionF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNOutputDropoutF style hasDropout)
                          (FFNOutputLayerNormF
                             style gradient device dataType inputEmbedDim)))))))
        (ModelSpec
           (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TEFinalDropoutF style hasDropout)))
encoderSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
     (ModelSpec
        (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TERelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numEncoderLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 ()
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'T5
ST5
      encoderSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
     (ModelSpec
        (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TERelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numEncoderLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 ()
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'ByT5
SByT5
      encoderSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
     (ModelSpec
        (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TERelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numEncoderLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 ()
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'BART
SBART
      encoderSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
     (ModelSpec
        (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TERelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numEncoderLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 ()
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'MBART
SMBART
      encoderSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
     (ModelSpec
        (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TERelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numEncoderLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 ()
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'Pegasus
SPegasus
      encoderSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
      encoderSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
      encoderSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      -- Decoder specification, selected by the transformer style singleton.
      --
      -- Each supported style wraps the style-specific decoder specification
      -- produced by decoderSpec' in a 'NamedModel' whose state-dict key prefix
      -- is "decoder." for T5 and ByT5, and "model.decoder." for BART, MBART
      -- and Pegasus.
      --
      -- Inferred type (abbreviated):
      --
      -- > decoderSpec
      -- >   :: STransformerStyle style
      -- >   -> NamedModel
      -- >        (GTransformer
      -- >           (ModelSpec (TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
      -- >           (ModelSpec (TDRelPosEncF style gradient device dataType headDim posEncDim))
      -- >           (ModelSpec (TDInitialLayerNormF style gradient device dataType inputEmbedDim))
      -- >           (ModelSpec (TDInitialDropoutF style hasDropout))
      -- >           (NamedModel (GTransformerStack (VectorSpec numDecoderLayers (GTransformerBlock ...))))
      -- >           (ModelSpec (TDFinalLayerNormF style gradient device dataType inputEmbedDim))
      -- >           (ModelSpec (TDFinalDropoutF style hasDropout)))
      --
      -- where each block of the stack is made of a self-attention layer, a
      -- cross-attention layer and a feed-forward network.
      decoderSpec ST5 = NamedModel "decoder." $ decoderSpec' ST5
      decoderSpec SByT5 = NamedModel "decoder." $ decoderSpec' SByT5
      decoderSpec SBART = NamedModel "model.decoder." $ decoderSpec' SBART
      decoderSpec SMBART = NamedModel "model.decoder." $ decoderSpec' SMBART
      decoderSpec SPegasus = NamedModel "model.decoder." $ decoderSpec' SPegasus
      -- No decoder specification is provided for the remaining styles; fail
      -- with a descriptive message instead of a bare 'undefined'.
      decoderSpec SBERT = error "decoderSpec: unsupported transformer style SBERT"
      decoderSpec SRoBERTa = error "decoderSpec: unsupported transformer style SRoBERTa"
      decoderSpec SGPT2 = error "decoderSpec: unsupported transformer style SGPT2"
      -- Shared embedding specification, selected by the transformer style
      -- singleton.
      --
      -- Every supported style reuses the same underlying sharedEmbeddingSpec'
      -- and differs only in the state-dict key prefix of the surrounding
      -- 'NamedModel': "shared." for T5 and ByT5, and "model.shared." for BART,
      -- MBART and Pegasus.
      --
      -- Inferred type:
      --
      -- > sharedEmbeddingSpec
      -- >   :: STransformerStyle style
      -- >   -> NamedModel
      -- >        (EmbeddingSpec
      -- >           gradient ('Layout 'Dense) device dataType vocabDim inputEmbedDim 'Nothing)
      sharedEmbeddingSpec ST5 = NamedModel "shared." sharedEmbeddingSpec'
      sharedEmbeddingSpec SByT5 = NamedModel "shared." sharedEmbeddingSpec'
      sharedEmbeddingSpec SBART = NamedModel "model.shared." sharedEmbeddingSpec'
      sharedEmbeddingSpec SMBART = NamedModel "model.shared." sharedEmbeddingSpec'
      sharedEmbeddingSpec SPegasus = NamedModel "model.shared." sharedEmbeddingSpec'
      -- No shared embedding specification is provided for the remaining
      -- styles; fail with a descriptive message instead of a bare 'undefined'.
      sharedEmbeddingSpec SBERT = error "sharedEmbeddingSpec: unsupported transformer style SBERT"
      sharedEmbeddingSpec SRoBERTa = error "sharedEmbeddingSpec: unsupported transformer style SRoBERTa"
      sharedEmbeddingSpec SGPT2 = error "sharedEmbeddingSpec: unsupported transformer style SGPT2"
      headSpec :: STransformerStyle style
-> STransformerHead transformerHead
-> ModelSpec
     (EDTHeadF
        style
        transformerHead
        gradient
        device
        dataType
        inputEmbedDim
        vocabDim)
headSpec STransformerStyle style
ST5 STransformerHead transformerHead
SWithoutHead = ()
      headSpec STransformerStyle style
ST5 STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"lm_head." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
     inputEmbedDim
     (ModelSpec
        (LMHeadDenseF style gradient device dataType inputEmbedDim))
     (ModelSpec (LMHeadActivationF style))
     (ModelSpec
        (LMHeadLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec
        (LMHeadDecoderF
           style gradient device dataType inputEmbedDim vocabDim))
     (ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'T5
ST5
      headSpec STransformerStyle style
SByT5 STransformerHead transformerHead
SWithoutHead = ()
      headSpec STransformerStyle style
SByT5 STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"lm_head." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
     inputEmbedDim
     (ModelSpec
        (LMHeadDenseF style gradient device dataType inputEmbedDim))
     (ModelSpec (LMHeadActivationF style))
     (ModelSpec
        (LMHeadLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec
        (LMHeadDecoderF
           style gradient device dataType inputEmbedDim vocabDim))
     (ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'ByT5
SByT5
      headSpec STransformerStyle style
SBART STransformerHead transformerHead
SWithoutHead = ()
      headSpec STransformerStyle style
SBART STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
     inputEmbedDim
     (ModelSpec
        (LMHeadDenseF style gradient device dataType inputEmbedDim))
     (ModelSpec (LMHeadActivationF style))
     (ModelSpec
        (LMHeadLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec
        (LMHeadDecoderF
           style gradient device dataType inputEmbedDim vocabDim))
     (ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'BART
SBART
      headSpec STransformerStyle style
SMBART STransformerHead transformerHead
SWithoutHead = ()
      headSpec STransformerStyle style
SMBART STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
     inputEmbedDim
     (ModelSpec
        (LMHeadDenseF style gradient device dataType inputEmbedDim))
     (ModelSpec (LMHeadActivationF style))
     (ModelSpec
        (LMHeadLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec
        (LMHeadDecoderF
           style gradient device dataType inputEmbedDim vocabDim))
     (ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'MBART
SMBART
      headSpec STransformerStyle style
SPegasus STransformerHead transformerHead
SWithoutHead = ()
      headSpec STransformerStyle style
SPegasus STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
     inputEmbedDim
     (ModelSpec
        (LMHeadDenseF style gradient device dataType inputEmbedDim))
     (ModelSpec (LMHeadActivationF style))
     (ModelSpec
        (LMHeadLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec
        (LMHeadDecoderF
           style gradient device dataType inputEmbedDim vocabDim))
     (ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'Pegasus
SPegasus
      headSpec STransformerStyle style
SBERT STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
      headSpec STransformerStyle style
SRoBERTa STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
      headSpec STransformerStyle style
SGPT2 STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
      embedScalingSpec :: STransformerStyle style -> EncoderDecoderTransformerHasEmbedScaling
      embedScalingSpec :: STransformerStyle style -> EncoderDecoderTransformerHasEmbedScaling
embedScalingSpec STransformerStyle style
ST5 = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling
      embedScalingSpec STransformerStyle style
SByT5 = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling
      embedScalingSpec STransformerStyle style
SBART = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling
      embedScalingSpec STransformerStyle style
SMBART = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling
      embedScalingSpec STransformerStyle style
SPegasus = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithEmbedScaling
      embedScalingSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
      embedScalingSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
      embedScalingSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
   in forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
       decoder sharedEmbedding head.
SDim inputEmbedDim
-> encoder
-> decoder
-> sharedEmbedding
-> head
-> EncoderDecoderTransformerHasEmbedScaling
-> GEncoderDecoderTransformer
     inputEmbedDim encoder decoder sharedEmbedding head
GEncoderDecoderTransformer SDim inputEmbedDim
inputEmbedDim (STransformerStyle style
-> NamedModel
     (GTransformer
        (ModelSpec
           (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
        (ModelSpec
           (TERelPosEncF style gradient device dataType headDim posEncDim))
        (ModelSpec
           (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TEInitialDropoutF style hasDropout))
        (NamedModel
           (GTransformerStack
              (VectorSpec
                 numEncoderLayers
                 (GTransformerBlock
                    (NamedModel
                       (GSelfAttention
                          (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                          (NamedModel
                             (GMultiHeadAttention
                                headDim
                                headEmbedDim
                                embedDim
                                (QInProjF style gradient device dataType inputEmbedDim embedDim)
                                (KInProjF style gradient device dataType inputEmbedDim embedDim)
                                (VInProjF style gradient device dataType inputEmbedDim embedDim)
                                (OutProjF style gradient device dataType embedDim inputEmbedDim)
                                (DropoutF style hasDropout)))
                          (SADropoutF style hasDropout)
                          (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                    ()
                    (NamedModel
                       (GTransformerFeedForwardNetwork
                          (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                          (FFNInputTransformationF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNActivationF style)
                          (FFNActivationDropoutF style hasDropout)
                          (FFNOutputProjectionF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNOutputDropoutF style hasDropout)
                          (FFNOutputLayerNormF
                             style gradient device dataType inputEmbedDim)))))))
        (ModelSpec
           (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TEFinalDropoutF style hasDropout)))
encoderSpec STransformerStyle style
style) (STransformerStyle style
-> NamedModel
     (GTransformer
        (ModelSpec
           (TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
        (ModelSpec
           (TDRelPosEncF style gradient device dataType headDim posEncDim))
        (ModelSpec
           (TDInitialLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TDInitialDropoutF style hasDropout))
        (NamedModel
           (GTransformerStack
              (VectorSpec
                 numDecoderLayers
                 (GTransformerBlock
                    (NamedModel
                       (GSelfAttention
                          (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                          (NamedModel
                             (GMultiHeadAttention
                                headDim
                                headEmbedDim
                                embedDim
                                (QInProjF style gradient device dataType inputEmbedDim embedDim)
                                (KInProjF style gradient device dataType inputEmbedDim embedDim)
                                (VInProjF style gradient device dataType inputEmbedDim embedDim)
                                (OutProjF style gradient device dataType embedDim inputEmbedDim)
                                (DropoutF style hasDropout)))
                          (SADropoutF style hasDropout)
                          (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                    (NamedModel
                       (GCrossAttention
                          (CAInitialLayerNormF style gradient device dataType inputEmbedDim)
                          (NamedModel
                             (GMultiHeadAttention
                                headDim
                                headEmbedDim
                                embedDim
                                (QInProjF style gradient device dataType inputEmbedDim embedDim)
                                (KInProjF style gradient device dataType inputEmbedDim embedDim)
                                (VInProjF style gradient device dataType inputEmbedDim embedDim)
                                (OutProjF style gradient device dataType embedDim inputEmbedDim)
                                (DropoutF style hasDropout)))
                          (CADropoutF style hasDropout)
                          (CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                    (NamedModel
                       (GTransformerFeedForwardNetwork
                          (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                          (FFNInputTransformationF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNActivationF style)
                          (FFNActivationDropoutF style hasDropout)
                          (FFNOutputProjectionF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNOutputDropoutF style hasDropout)
                          (FFNOutputLayerNormF
                             style gradient device dataType inputEmbedDim)))))))
        (ModelSpec
           (TDFinalLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TDFinalDropoutF style hasDropout)))
decoderSpec STransformerStyle style
style) (STransformerStyle style
-> NamedModel
     (EmbeddingSpec
        gradient
        ('Layout 'Dense)
        device
        dataType
        vocabDim
        inputEmbedDim
        'Nothing)
sharedEmbeddingSpec STransformerStyle style
style) (STransformerStyle style
-> STransformerHead transformerHead
-> ModelSpec
     (EDTHeadF
        style
        transformerHead
        gradient
        device
        dataType
        inputEmbedDim
        vocabDim)
headSpec STransformerStyle style
style STransformerHead transformerHead
transformerHead) (STransformerStyle style -> EncoderDecoderTransformerHasEmbedScaling
embedScalingSpec STransformerStyle style
style)
  where
    encoderSpec' :: _
    encoderSpec' :: STransformerStyle style
-> GTransformer
     (ModelSpec
        (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TERelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numEncoderLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 ()
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle style
style' = forall (style :: TransformerStyle) (numLayers :: Nat)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
       (ffnDim :: Dim (Name Symbol) (Size Nat))
       (posEncDim :: Dim (Name Symbol) (Size Nat))
       (hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (TransformerEncoderF
        style
        numLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        inputEmbedDim
        ffnDim
        posEncDim
        hasDropout)
transformerEncoderSpec STransformerStyle style
style' SNat numEncoderLayers
numEncoderLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps
    decoderSpec' :: _
    decoderSpec' :: STransformerStyle style
-> GTransformer
     (ModelSpec
        (TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TDRelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TDInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TDInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numDecoderLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 (NamedModel
                    (GCrossAttention
                       (CAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (CADropoutF style hasDropout)
                       (CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TDFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TDFinalDropoutF style hasDropout))
decoderSpec' STransformerStyle style
style' = forall (style :: TransformerStyle) (numLayers :: Nat)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (decoderInputEmbedDim :: Dim (Name Symbol) (Size Nat))
       (encoderOutputEmbedDim :: Dim (Name Symbol) (Size Nat))
       (ffnDim :: Dim (Name Symbol) (Size Nat))
       (posEncDim :: Dim (Name Symbol) (Size Nat))
       (hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim decoderInputEmbedDim
-> SDim encoderOutputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (TransformerDecoderF
        style
        numLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        decoderInputEmbedDim
        encoderOutputEmbedDim
        ffnDim
        posEncDim
        hasDropout)
transformerDecoderSpec STransformerStyle style
style' SNat numDecoderLayers
numDecoderLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps
    sharedEmbeddingSpec' :: EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  vocabDim
  inputEmbedDim
  'Nothing
sharedEmbeddingSpec' = forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim vocabDim
vocabDim SDim inputEmbedDim
inputEmbedDim forall a. SMaybe 'Nothing
SNothing
    headSpec' :: _
    headSpec' :: STransformerStyle style
-> GLMHead
     inputEmbedDim
     (ModelSpec
        (LMHeadDenseF style gradient device dataType inputEmbedDim))
     (ModelSpec (LMHeadActivationF style))
     (ModelSpec
        (LMHeadLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec
        (LMHeadDecoderF
           style gradient device dataType inputEmbedDim vocabDim))
     (ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle style
style' = forall (style :: TransformerStyle)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
       (vocabDim :: Dim (Name Symbol) (Size Nat)).
STransformerStyle style
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputEmbedDim
-> SDim vocabDim
-> Double
-> ModelSpec
     (GLMHeadF style gradient device dataType inputEmbedDim vocabDim)
lmHeadSpec STransformerStyle style
style' SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim inputEmbedDim
inputEmbedDim SDim vocabDim
vocabDim Double
eps

-- | 'HasInitialize' instance for 'GEncoderDecoderTransformer', assembled from
-- the 'HasInitialize' instances of its four components.
--
-- The constraints chain the generator device type through the components in
-- a fixed order: @encoder@ goes from @generatorDevice@ to @generatorDevice0@,
-- @decoder@ from @generatorDevice0@ to @generatorDevice1@, @sharedEmbedding@
-- from @generatorDevice1@ to @generatorDevice2@, and @head@ from
-- @generatorDevice2@ to @generatorOutputDevice@. The primed component types
-- (@encoder'@, @decoder'@, @sharedEmbedding'@, @head'@) become the parameters
-- of the resulting 'GEncoderDecoderTransformer'; @inputEmbedDim@ is the same
-- on both sides.
--
-- No methods are defined here, so the class's default implementations are
-- used. NOTE(review): presumably these defaults are Generic-based; confirm
-- against the class definition in "Torch.GraduallyTyped.NN.Class".
instance
  ( HasInitialize encoder generatorDevice encoder' generatorDevice0,
    HasInitialize decoder generatorDevice0 decoder' generatorDevice1,
    HasInitialize sharedEmbedding generatorDevice1 sharedEmbedding' generatorDevice2,
    HasInitialize head generatorDevice2 head' generatorOutputDevice
  ) =>
  HasInitialize
    (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
    generatorDevice
    (GEncoderDecoderTransformer inputEmbedDim encoder' decoder' sharedEmbedding' head')
    generatorOutputDevice

-- | 'HasStateDict' instance for 'GEncoderDecoderTransformer'.
--
-- Requires a 'HasStateDict' instance for each of the four components
-- (@encoder@, @decoder@, @sharedEmbedding@, and @head@).
--
-- No methods are defined here, so the class's default implementations are
-- used. NOTE(review): presumably these defaults are Generic-based and handle
-- the components field by field; confirm against the class definition in
-- "Torch.GraduallyTyped.NN.Class".
instance
  ( HasStateDict encoder,
    HasStateDict decoder,
    HasStateDict sharedEmbedding,
    HasStateDict head
  ) =>
  HasStateDict (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)

data
  GSimplifiedEncoderDecoderTransformer
    (model :: Type)
    (mkPos :: Type)
    (mkDecoderPos :: Type)
    (mkPaddingMask :: Type)
    (mkAttentionMask :: Type)
    (mkCrossAttentionMask :: Type)
    (mkDecoderAttentionMask :: Type)
  where
  GSimplifiedEncoderDecoderTransformer ::
    forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask.
    { -- | encoder-decoder model
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> model
sedtModel :: model,
      -- | shift for decoder input
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> ShiftRight Int
sedtDecoderInputShift :: ShiftRight Int,
      -- | shift for padding mask
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> ShiftRight Int
sedtPaddingMaskShift :: ShiftRight Int,
      -- | make encoder input positions
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> mkPos
sedtMkPos :: mkPos,
      -- | make decoder input positions
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> mkDecoderPos
sedtMkDecoderPos :: mkDecoderPos,
      -- | make padding mask
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> mkPaddingMask
sedtMkPaddingMask :: mkPaddingMask,
      -- | make attention mask
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> mkAttentionMask
sedtMkAttentionMask :: mkAttentionMask,
      -- | make cross-attention mask
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> mkCrossAttentionMask
sedtMkCrossAttentionMask :: mkCrossAttentionMask,
      -- | make decoder attention mask
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> mkDecoderAttentionMask
sedtMkDecoderAttentionMask :: mkDecoderAttentionMask
    } ->
    GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask
  deriving stock (GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Eq model, Eq mkPos, Eq mkDecoderPos, Eq mkPaddingMask,
 Eq mkAttentionMask, Eq mkCrossAttentionMask,
 Eq mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
/= :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
$c/= :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Eq model, Eq mkPos, Eq mkDecoderPos, Eq mkPaddingMask,
 Eq mkAttentionMask, Eq mkCrossAttentionMask,
 Eq mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
== :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
$c== :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Eq model, Eq mkPos, Eq mkDecoderPos, Eq mkPaddingMask,
 Eq mkAttentionMask, Eq mkCrossAttentionMask,
 Eq mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
Eq, GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Ordering
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall {model} {mkPos} {mkDecoderPos} {mkPaddingMask}
       {mkAttentionMask} {mkCrossAttentionMask} {mkDecoderAttentionMask}.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
Eq
  (GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask)
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Ordering
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
min :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
$cmin :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
max :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
$cmax :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
>= :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
$c>= :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
> :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
$c> :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
<= :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
$c<= :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
< :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
$c< :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Bool
compare :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Ordering
$ccompare :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
 Ord mkAttentionMask, Ord mkCrossAttentionMask,
 Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> Ordering
Ord, Int
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
 Show mkAttentionMask, Show mkCrossAttentionMask,
 Show mkDecoderAttentionMask) =>
Int
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> ShowS
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
 Show mkAttentionMask, Show mkCrossAttentionMask,
 Show mkDecoderAttentionMask) =>
[GSimplifiedEncoderDecoderTransformer
   model
   mkPos
   mkDecoderPos
   mkPaddingMask
   mkAttentionMask
   mkCrossAttentionMask
   mkDecoderAttentionMask]
-> ShowS
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
 Show mkAttentionMask, Show mkCrossAttentionMask,
 Show mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> String
showList :: [GSimplifiedEncoderDecoderTransformer
   model
   mkPos
   mkDecoderPos
   mkPaddingMask
   mkAttentionMask
   mkCrossAttentionMask
   mkDecoderAttentionMask]
-> ShowS
$cshowList :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
 Show mkAttentionMask, Show mkCrossAttentionMask,
 Show mkDecoderAttentionMask) =>
[GSimplifiedEncoderDecoderTransformer
   model
   mkPos
   mkDecoderPos
   mkPaddingMask
   mkAttentionMask
   mkCrossAttentionMask
   mkDecoderAttentionMask]
-> ShowS
show :: GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> String
$cshow :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
 Show mkAttentionMask, Show mkCrossAttentionMask,
 Show mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> String
showsPrec :: Int
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> ShowS
$cshowsPrec :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
 Show mkAttentionMask, Show mkCrossAttentionMask,
 Show mkDecoderAttentionMask) =>
Int
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask x.
Rep
  (GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask)
  x
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask x.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> Rep
     (GSimplifiedEncoderDecoderTransformer
        model
        mkPos
        mkDecoderPos
        mkPaddingMask
        mkAttentionMask
        mkCrossAttentionMask
        mkDecoderAttentionMask)
     x
$cto :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask x.
Rep
  (GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask)
  x
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
$cfrom :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask x.
GSimplifiedEncoderDecoderTransformer
  model
  mkPos
  mkDecoderPos
  mkPaddingMask
  mkAttentionMask
  mkCrossAttentionMask
  mkDecoderAttentionMask
-> Rep
     (GSimplifiedEncoderDecoderTransformer
        model
        mkPos
        mkDecoderPos
        mkPaddingMask
        mkAttentionMask
        mkCrossAttentionMask
        mkDecoderAttentionMask)
     x
Generic)

-- | The specification of a 'GSimplifiedEncoderDecoderTransformer' is the same
-- type constructor applied to the 'ModelSpec' of each of its seven type
-- arguments, taken in the same order.
type instance
  ModelSpec
    ( GSimplifiedEncoderDecoderTransformer
        model
        mkPos
        mkDecoderPos
        mkPaddingMask
        mkAttentionMask
        mkCrossAttentionMask
        mkDecoderAttentionMask
    ) =
    GSimplifiedEncoderDecoderTransformer
      (ModelSpec model)
      (ModelSpec mkPos)
      (ModelSpec mkDecoderPos)
      (ModelSpec mkPaddingMask)
      (ModelSpec mkAttentionMask)
      (ModelSpec mkCrossAttentionMask)
      (ModelSpec mkDecoderAttentionMask)

-- | Initialization of a 'GSimplifiedEncoderDecoderTransformer' from the
-- initialization of its seven components.
--
-- The generator device is threaded through the component constraints in
-- declaration order: the output generator device of one component is the
-- input generator device of the next, starting at @generatorDevice@ and
-- ending at @generatorOutputDevice@.
--
-- No methods are defined in this instance, so whatever defaults the
-- 'HasInitialize' class declares are used.
instance
  ( HasInitialize model generatorDevice model' deviceAfterModel,
    HasInitialize mkPos deviceAfterModel mkPos' deviceAfterPos,
    HasInitialize mkDecoderPos deviceAfterPos mkDecoderPos' deviceAfterDecoderPos,
    HasInitialize mkPaddingMask deviceAfterDecoderPos mkPaddingMask' deviceAfterPaddingMask,
    HasInitialize mkAttentionMask deviceAfterPaddingMask mkAttentionMask' deviceAfterAttentionMask,
    HasInitialize mkCrossAttentionMask deviceAfterAttentionMask mkCrossAttentionMask' deviceAfterCrossAttentionMask,
    HasInitialize mkDecoderAttentionMask deviceAfterCrossAttentionMask mkDecoderAttentionMask' generatorOutputDevice
  ) =>
  HasInitialize
    ( GSimplifiedEncoderDecoderTransformer
        model
        mkPos
        mkDecoderPos
        mkPaddingMask
        mkAttentionMask
        mkCrossAttentionMask
        mkDecoderAttentionMask
    )
    generatorDevice
    ( GSimplifiedEncoderDecoderTransformer
        model'
        mkPos'
        mkDecoderPos'
        mkPaddingMask'
        mkAttentionMask'
        mkCrossAttentionMask'
        mkDecoderAttentionMask'
    )
    generatorOutputDevice

-- | A 'GSimplifiedEncoderDecoderTransformer' has a state dictionary whenever
-- each of its seven components does.
--
-- No methods are defined in this instance, so whatever defaults the
-- 'HasStateDict' class declares are used.
instance
  ( HasStateDict model,
    HasStateDict mkPos,
    HasStateDict mkDecoderPos,
    HasStateDict mkPaddingMask,
    HasStateDict mkAttentionMask,
    HasStateDict mkCrossAttentionMask,
    HasStateDict mkDecoderAttentionMask
  ) =>
  HasStateDict
    ( GSimplifiedEncoderDecoderTransformer
        model
        mkPos
        mkDecoderPos
        mkPaddingMask
        mkAttentionMask
        mkCrossAttentionMask
        mkDecoderAttentionMask
    )

-- | Input data type for use with an encoder-decoder transformer.
-- Use this for training.
--
-- The record is fully polymorphic: every field stores a value of the
-- like-named type parameter, and no constraints are placed on those
-- parameters here. How the fields are interpreted is decided by the code
-- that consumes this record, not by this declaration.
--
-- The 'Eq', 'Ord', 'Show' and 'Generic' instances are stock-derived; the
-- derived 'Ord' compares fields in declaration order.
data EncoderDecoderTransformerInput input decoderInput pos decoderPos attentionMask decoderAttentionMask crossAttentionMask where
  EncoderDecoderTransformerInput ::
    forall input decoderInput pos decoderPos attentionMask decoderAttentionMask crossAttentionMask.
    { -- | Value of type @input@.
      edtInput :: input,
      -- | Value of type @decoderInput@.
      edtDecoderInput :: decoderInput,
      -- | Value of type @pos@.
      edtPos :: pos,
      -- | Value of type @decoderPos@.
      edtDecoderPos :: decoderPos,
      -- | Value of type @attentionMask@.
      edtAttentionMask :: attentionMask,
      -- | Value of type @decoderAttentionMask@.
      edtDecoderAttentionMask :: decoderAttentionMask,
      -- | Value of type @crossAttentionMask@.
      edtCrossAttentionMask :: crossAttentionMask
    } ->
    EncoderDecoderTransformerInput input decoderInput pos decoderPos attentionMask decoderAttentionMask crossAttentionMask
  deriving stock (Eq, Ord, Show, Generic)

-- | Reduced input record for an encoder-decoder transformer.
--
-- Compared to 'EncoderDecoderTransformerInput', this record carries only
-- three of its seven components: the input, the positions, and the attention
-- mask. All three field types are left fully polymorphic.
--
-- NOTE(review): presumably this is the variant used when no decoder-side
-- inputs are supplied up front (e.g. inference) -- confirm against the
-- 'HasForward' instances that consume it.
data EncoderDecoderTransformerInput' input pos attentionMask where
  EncoderDecoderTransformerInput' ::
    forall input pos attentionMask.
    { -- | The input component.
      edtInput' :: input,
      -- | The position component that accompanies 'edtInput''.
      edtPos' :: pos,
      -- | The attention-mask component that accompanies 'edtInput''.
      edtAttentionMask' :: attentionMask
    } ->
    EncoderDecoderTransformerInput' input pos attentionMask
  deriving stock (Eq, Ord, Show, Generic)

-- | Simplified input record for an encoder-decoder transformer.
--
-- Holds only an input and a decoder input; both field types are left fully
-- polymorphic.
--
-- NOTE(review): positions and masks are not part of this record; presumably
-- they are derived from these two fields by the model that consumes it --
-- confirm against the corresponding 'HasForward' instance.
data SimplifiedEncoderDecoderTransformerInput input decoderInput where
  SimplifiedEncoderDecoderTransformerInput ::
    forall input decoderInput.
    { -- | The input component.
      sedtInput :: input,
      -- | The decoder-input component.
      sedtDecoderInput :: decoderInput
    } ->
    SimplifiedEncoderDecoderTransformerInput input decoderInput
  deriving stock (Eq, Ord, Show, Generic)

-- | Simplified input record that carries a single input value.
--
-- This is the one-field counterpart of
-- 'SimplifiedEncoderDecoderTransformerInput': it has no decoder-input field.
-- The field type is left fully polymorphic.
--
-- NOTE(review): presumably used when the decoder input is not supplied by the
-- caller -- confirm against the 'HasForward' instance that consumes it.
data SimplifiedEncoderDecoderTransformerInput' input where
  SimplifiedEncoderDecoderTransformerInput' ::
    forall input.
    { -- | The input component.
      sedtInput' :: input
    } ->
    SimplifiedEncoderDecoderTransformerInput' input
  deriving stock (Eq, Ord, Show, Generic)

-- | Simplified training input record: an input paired with a target.
--
-- Both field types are left fully polymorphic.
--
-- NOTE(review): presumably the target is what the model output is scored
-- against during training -- confirm against the 'HasForward' instance that
-- consumes this record.
data SimplifiedEncoderDecoderTransformerTrainingInput input target where
  SimplifiedEncoderDecoderTransformerTrainingInput ::
    forall input target.
    { -- | The input component.
      sedtTrainingInput :: input,
      -- | The target component.
      sedtTarget :: target
    } ->
    SimplifiedEncoderDecoderTransformerTrainingInput input target
  deriving stock (Eq, Ord, Show, Generic)

-- | Output data type for use with an encoder-decoder transformer.
data EncoderDecoderTransformerOutput decoderOutput encoderOutput where
  EncoderDecoderTransformerOutput ::
    forall decoderOutput encoderOutput.
    { forall decoderOutput encoderOutput.
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> decoderOutput
edtDecoderOutput :: decoderOutput,
      forall decoderOutput encoderOutput.
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> encoderOutput
edtEncoderOutput :: encoderOutput
    } ->
    EncoderDecoderTransformerOutput decoderOutput encoderOutput
  deriving stock (EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall decoderOutput encoderOutput.
(Eq decoderOutput, Eq encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
/= :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
$c/= :: forall decoderOutput encoderOutput.
(Eq decoderOutput, Eq encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
== :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
$c== :: forall decoderOutput encoderOutput.
(Eq decoderOutput, Eq encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
Eq, EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Ordering
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall {decoderOutput} {encoderOutput}.
(Ord decoderOutput, Ord encoderOutput) =>
Eq (EncoderDecoderTransformerOutput decoderOutput encoderOutput)
forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Ordering
forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
min :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
$cmin :: forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
max :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
$cmax :: forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
>= :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
$c>= :: forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
> :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
$c> :: forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
<= :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
$c<= :: forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
< :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
$c< :: forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Bool
compare :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Ordering
$ccompare :: forall decoderOutput encoderOutput.
(Ord decoderOutput, Ord encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Ordering
Ord, Int
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall decoderOutput encoderOutput.
(Show decoderOutput, Show encoderOutput) =>
Int
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> ShowS
forall decoderOutput encoderOutput.
(Show decoderOutput, Show encoderOutput) =>
[EncoderDecoderTransformerOutput decoderOutput encoderOutput]
-> ShowS
forall decoderOutput encoderOutput.
(Show decoderOutput, Show encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> String
showList :: [EncoderDecoderTransformerOutput decoderOutput encoderOutput]
-> ShowS
$cshowList :: forall decoderOutput encoderOutput.
(Show decoderOutput, Show encoderOutput) =>
[EncoderDecoderTransformerOutput decoderOutput encoderOutput]
-> ShowS
show :: EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> String
$cshow :: forall decoderOutput encoderOutput.
(Show decoderOutput, Show encoderOutput) =>
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> String
showsPrec :: Int
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> ShowS
$cshowsPrec :: forall decoderOutput encoderOutput.
(Show decoderOutput, Show encoderOutput) =>
Int
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall decoderOutput encoderOutput x.
Rep (EncoderDecoderTransformerOutput decoderOutput encoderOutput) x
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
forall decoderOutput encoderOutput x.
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Rep
     (EncoderDecoderTransformerOutput decoderOutput encoderOutput) x
$cto :: forall decoderOutput encoderOutput x.
Rep (EncoderDecoderTransformerOutput decoderOutput encoderOutput) x
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
$cfrom :: forall decoderOutput encoderOutput x.
EncoderDecoderTransformerOutput decoderOutput encoderOutput
-> Rep
     (EncoderDecoderTransformerOutput decoderOutput encoderOutput) x
Generic)

-- | Record with a single field, 'edtEncoderOutput'', of type @encoderOutput@.
--
-- The type parameter is left fully polymorphic; the derived 'Eq', 'Ord' and
-- 'Show' instances therefore require the corresponding instance for
-- @encoderOutput@.
--
-- NOTE(review): judging by the name, this is the primed (encoder-only)
-- counterpart of @EncoderDecoderTransformerOutput@ -- confirm against the
-- 'HasForward' instances that produce it.
data EncoderDecoderTransformerOutput' encoderOutput where
  EncoderDecoderTransformerOutput' ::
    forall encoderOutput.
    { -- The only payload carried by this record.
      edtEncoderOutput' :: encoderOutput
    } ->
    EncoderDecoderTransformerOutput' encoderOutput
  deriving stock (Eq, Ord, Show, Generic)

-- | Record with four independently typed fields: a decoder output, an encoder
-- output, a decoder input and an input padding mask.
--
-- All four type parameters are left fully polymorphic. The derived 'Eq',
-- 'Ord' and 'Show' instances require the corresponding instance for every
-- parameter; the derived 'Ord' compares fields in declaration order.
--
-- NOTE(review): the field name 'sedtOriginalDecoderInput' suggests it holds
-- the decoder input as supplied by the caller, before any internal
-- preprocessing -- confirm against the 'HasForward' instance that builds
-- this record.
data SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask where
  SimplifiedEncoderDecoderTransformerOutput ::
    forall decoderOutput encoderOutput decoderInput inputPaddingMask.
    { -- Value of type @decoderOutput@.
      sedtDecoderOutput :: decoderOutput,
      -- Value of type @encoderOutput@.
      sedtEncoderOutput :: encoderOutput,
      -- Value of type @decoderInput@ (see the review note above the type).
      sedtOriginalDecoderInput :: decoderInput,
      -- Value of type @inputPaddingMask@.
      sedtInputPaddingMask :: inputPaddingMask
    } ->
    SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask
  deriving stock (Eq, Ord, Show, Generic)

-- | Record with two independently typed fields: an encoder output and an
-- input padding mask.
--
-- Both type parameters are left fully polymorphic. The derived 'Eq', 'Ord'
-- and 'Show' instances require the corresponding instance for each
-- parameter; the derived 'Ord' compares 'sedtEncoderOutput'' first and
-- 'sedtInputPaddingMask'' second.
--
-- NOTE(review): judging by the name, this is the primed (encoder-only)
-- counterpart of @SimplifiedEncoderDecoderTransformerOutput@ -- confirm
-- against the 'HasForward' instances that produce it.
data SimplifiedEncoderDecoderTransformerOutput' encoderOutput inputPaddingMask where
  SimplifiedEncoderDecoderTransformerOutput' ::
    forall encoderOutput inputPaddingMask.
    { -- Value of type @encoderOutput@.
      sedtEncoderOutput' :: encoderOutput,
      -- Value of type @inputPaddingMask@.
      sedtInputPaddingMask' :: inputPaddingMask
    } ->
    SimplifiedEncoderDecoderTransformerOutput' encoderOutput inputPaddingMask
  deriving stock (Eq, Ord, Show, Generic)

-- | Training-output record that wraps a single value of the polymorphic
-- type @loss@.
--
-- The record is deliberately parametric: it places no constraint on @loss@
-- itself. The stock-derived 'Eq', 'Ord' and 'Show' instances each require the
-- corresponding instance for @loss@.
--
-- NOTE(review): judging by the name, this is presumably what the simplified
-- encoder-decoder transformer returns in training mode -- confirm against the
-- corresponding 'HasForward' instance.
data SimplifiedEncoderDecoderTransformerTrainingOutput loss where
  SimplifiedEncoderDecoderTransformerTrainingOutput ::
    forall loss.
    { -- | The wrapped value of type @loss@.
      sedtLoss :: loss
    } ->
    SimplifiedEncoderDecoderTransformerTrainingOutput loss
  deriving stock (Eq, Ord, Show, Generic)

-- | Input data type for use with an encoder-decoder transformer.
-- Use this for inference.
--
-- Every field has its own type parameter, so the record itself places no
-- constraints on what is stored in it. The stock-derived 'Eq', 'Ord' and
-- 'Show' instances require the corresponding instance for all five
-- parameters.
data EncoderDecoderTransformerGenerationInput decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask where
  EncoderDecoderTransformerGenerationInput ::
    forall decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask.
    { -- | Value of type @decoderInput@.
      edtGenerationDecoderInput :: decoderInput,
      -- | Value of type @encoderOutput@.
      -- NOTE(review): presumably a previously computed encoder result that is
      -- reused across generation steps -- confirm against the forward instance.
      edtGenerationEncoderOutput :: encoderOutput,
      -- | Value of type @decoderPos@.
      edtGenerationDecoderPos :: decoderPos,
      -- | Value of type @decoderAttentionMask@.
      edtGenerationDecoderAttentionMask :: decoderAttentionMask,
      -- | Value of type @crossAttentionMask@.
      edtGenerationCrossAttentionMask :: crossAttentionMask
    } ->
    EncoderDecoderTransformerGenerationInput decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask
  deriving stock (Eq, Ord, Show, Generic)

data SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask where
  SimplifiedEncoderDecoderTransformerGenerationInput ::
    forall decoderInput encoderOutput inputPaddingMask.
    { forall decoderInput encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> decoderInput
sedtGenerationDecoderInput :: decoderInput,
      forall decoderInput encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> encoderOutput
sedtGenerationEncoderOutput :: encoderOutput,
      forall decoderInput encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> inputPaddingMask
sedtGenerationInputPaddingMask :: inputPaddingMask
    } ->
    SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask
  deriving stock (SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall decoderInput encoderOutput inputPaddingMask.
(Eq decoderInput, Eq encoderOutput, Eq inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
/= :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
$c/= :: forall decoderInput encoderOutput inputPaddingMask.
(Eq decoderInput, Eq encoderOutput, Eq inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
== :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
$c== :: forall decoderInput encoderOutput inputPaddingMask.
(Eq decoderInput, Eq encoderOutput, Eq inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
Eq, SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Ordering
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall {decoderInput} {encoderOutput} {inputPaddingMask}.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
Eq
  (SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask)
forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Ordering
forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
min :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
$cmin :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
max :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
$cmax :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
>= :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
$c>= :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
> :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
$c> :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
<= :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
$c<= :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
< :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
$c< :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Bool
compare :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Ordering
$ccompare :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> Ordering
Ord, Int
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
Int
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> ShowS
forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
[SimplifiedEncoderDecoderTransformerGenerationInput
   decoderInput encoderOutput inputPaddingMask]
-> ShowS
forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> String
showList :: [SimplifiedEncoderDecoderTransformerGenerationInput
   decoderInput encoderOutput inputPaddingMask]
-> ShowS
$cshowList :: forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
[SimplifiedEncoderDecoderTransformerGenerationInput
   decoderInput encoderOutput inputPaddingMask]
-> ShowS
show :: SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> String
$cshow :: forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> String
showsPrec :: Int
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> ShowS
$cshowsPrec :: forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
Int
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall decoderInput encoderOutput inputPaddingMask x.
Rep
  (SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask)
  x
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
forall decoderInput encoderOutput inputPaddingMask x.
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> Rep
     (SimplifiedEncoderDecoderTransformerGenerationInput
        decoderInput encoderOutput inputPaddingMask)
     x
$cto :: forall decoderInput encoderOutput inputPaddingMask x.
Rep
  (SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask)
  x
-> SimplifiedEncoderDecoderTransformerGenerationInput
     decoderInput encoderOutput inputPaddingMask
$cfrom :: forall decoderInput encoderOutput inputPaddingMask x.
SimplifiedEncoderDecoderTransformerGenerationInput
  decoderInput encoderOutput inputPaddingMask
-> Rep
     (SimplifiedEncoderDecoderTransformerGenerationInput
        decoderInput encoderOutput inputPaddingMask)
     x
Generic)

-- | 'HasForward' instance for encoder-decoder transformers with optional head.
--
-- @
--     ┌───────┐  ┌─────┐  ┌───────────────┐  ┌──────────────┐  ┌────────────┐  ┌──────────────────────┐  ┌────────────────────┐
--     │ input │  │ pos │  │ attentionMask │  │ decoderInput │  │ decoderPos │  │ decoderAttentionMask │  │ crossAttentionMask │
--     └───┬───┘  └──┬──┘  └──────┬────────┘  └──────┬───────┘  └─────┬──────┘  └──────────┬───────────┘  └─────────┬──────────┘
--         │         │            │                  │                │                    │                        │
--         ▼         │            │                  │                │                    │                        │
-- edtSharedEmbedding│            │                  │                │                    │                        │
--         ▼         │            │                  │                │                    │                        │
--   (embedScaling)  │            │                  │                │                    │                        │
--         ▼         │            │                  │                │                    │                        │
--     edtEncoder◄───┘◄───────────┘                  ▼                │                    │                        │
--         │                                 edtSharedEmbedding       │                    │                        │
--         │                                         ▼                │                    │                        │
--         │                                   (embedScaling)         │                    │                        │
--         │                                         ▼                │                    │                        │
--         ├────────────────────────────────────►edtDecoder◄──────────┘◄───────────────────┘◄───────────────────────┘
--         │                                         ▼
--         │                                     (edtHead)
--         │                                         │
--         ▼                                         ▼
-- ┌───────────────┐                         ┌───────────────┐
-- │ encoderOutput │                         │ decoderOutput │
-- └───────────────┘                         └───────────────┘
-- @
instance
  ( HasForward
      (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
      (EncoderDecoderTransformerInput' input pos attentionMask)
      generatorDevice
      (EncoderDecoderTransformerOutput' encoderOutput)
      generatorDevice',
    HasForward
      (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
      (EncoderDecoderTransformerGenerationInput decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask)
      generatorDevice'
      (EncoderDecoderTransformerOutput headOutput encoderOutput)
      generatorOutputDevice
  ) =>
  HasForward
    (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
    (EncoderDecoderTransformerInput input decoderInput pos decoderPos attentionMask decoderAttentionMask crossAttentionMask)
    generatorDevice
    (EncoderDecoderTransformerOutput headOutput encoderOutput)
    generatorOutputDevice
  where
  -- The full pass is the composition of the two more specific instances
  -- required by the context: first the encoder-only pass on
  -- 'EncoderDecoderTransformerInput'', then the generation pass on
  -- 'EncoderDecoderTransformerGenerationInput'. 'IxStateT' threads the
  -- generator through both steps, so the generator returned by the first
  -- 'forward' is the one handed to the second.
  forward model EncoderDecoderTransformerInput {..} =
    runIxStateT $
      -- Step 1: encoder-only pass over the encoder-side fields.
      IxStateT
        ( forward
            model
            EncoderDecoderTransformerInput'
              { edtInput' = edtInput,
                edtPos' = edtPos,
                edtAttentionMask' = edtAttentionMask
              }
        )
        -- Step 2: reuse the encoder output for the decoder/head pass,
        -- together with the decoder-side fields of the original input.
        >>>= ( \EncoderDecoderTransformerOutput' {..} ->
                 IxStateT
                   ( forward
                       model
                       EncoderDecoderTransformerGenerationInput
                         { edtGenerationDecoderInput = edtDecoderInput,
                           edtGenerationEncoderOutput = edtEncoderOutput',
                           edtGenerationDecoderPos = edtDecoderPos,
                           edtGenerationDecoderAttentionMask = edtDecoderAttentionMask,
                           edtGenerationCrossAttentionMask = edtCrossAttentionMask
                         }
                   )
             )

-- | 'HasForward' instance for the encoder half of an encoder-decoder transformer.
--
-- The input is passed through 'edtSharedEmbedding', optionally multiplied by
-- the square root of the input embedding dimension (depending on
-- 'edtEmbedScaling'), and then handed to 'edtEncoder' together with the
-- positions and the attention mask. The encoder's result is returned wrapped
-- in 'EncoderDecoderTransformerOutput''. Neither the decoder nor the head is
-- used by this instance.
instance
  ( HasForward
      sharedEmbedding
      input
      generatorDevice
      embeddingOutput
      embeddingGeneratorOutputDevice,
    embeddingOutput ~ Tensor requiresGradient' layout' device' dataType' shape',
    HasForward
      encoder
      (embeddingOutput, pos, attentionMask)
      embeddingGeneratorOutputDevice
      encoderOutput
      generatorOutputDevice
  ) =>
  HasForward
    (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
    (EncoderDecoderTransformerInput' input pos attentionMask)
    generatorDevice
    (EncoderDecoderTransformerOutput' encoderOutput)
    generatorOutputDevice
  where
  forward
    GEncoderDecoderTransformer {edtInputEmbedDim, edtSharedEmbedding, edtEmbedScaling, edtEncoder}
    EncoderDecoderTransformerInput' {..} =
      let -- Scaling factor: square root of the (demoted) input embedding size.
          scaling :: Double
          scaling = sqrt . fromIntegral . forgetIsChecked . dimSize . fromSing $ edtInputEmbedDim
       in runIxStateT $
            ireturn edtInput'
              -- Embed the input; this step may advance the generator.
              >>>= IxStateT . forward edtSharedEmbedding
              -- Optionally scale the embeddings; this step leaves the generator untouched.
              >>>= ilift
                . ( \case
                      EncoderDecoderTransformerWithoutEmbedScaling -> pure
                      EncoderDecoderTransformerWithEmbedScaling -> flip mulScalar scaling
                  )
                  edtEmbedScaling
              -- Run the encoder on the (possibly scaled) embeddings.
              >>>= (\input' -> IxStateT $ forward edtEncoder (input', edtPos', edtAttentionMask'))
              >>>= (\encoderOutput -> ireturn (EncoderDecoderTransformerOutput' encoderOutput))

-- | 'HasForward' instance for encoder-decoder transformers with optional head.
-- Use this instance for sequence generation once the encoder's output is available.
--
-- @
-- ┌───────────────┐  ┌──────────────┐  ┌────────────┐  ┌──────────────────────┐  ┌────────────────────┐
-- │ encoderOutput │  │ decoderInput │  │ decoderPos │  │ decoderAttentionMask │  │ crossAttentionMask │
-- └───────┬───────┘  └───────┬──────┘  └──────┬─────┘  └───────────┬──────────┘  └──────────┬─────────┘
--         │                  │                │                    │                        │
--         │                  ▼                │                    │                        │
--         │          edtSharedEmbedding       │                    │                        │
--         │                  ▼                │                    │                        │
--         │            (embedScaling)         │                    │                        │
--         │                  ▼                │                    │                        │
--         ├────────────►edtDecoder◄───────────┘◄───────────────────┘◄───────────────────────┘
--         │                  │
--         │              (edtHead)
--         │                  │
--         ▼                  ▼
-- ┌───────────────┐  ┌───────────────┐
-- │ encoderOutput │  │ decoderOutput │
-- └───────────────┘  └───────────────┘
-- @
instance
  ( HasForward
      sharedEmbedding
      decoderInput
      generatorDevice
      embeddingOutput'
      embeddingGeneratorOutputDevice',
    embeddingOutput' ~ Tensor requiresGradient'' layout'' device'' dataType'' shape'',
    HasForward
      decoder
      ( embeddingOutput',
        encoderOutput,
        decoderPos,
        decoderAttentionMask,
        crossAttentionMask
      )
      embeddingGeneratorOutputDevice'
      decoderOutput
      decoderGeneratorOutputDevice,
    HasForward
      head
      decoderOutput
      decoderGeneratorOutputDevice
      headOutput
      generatorOutputDevice
  ) =>
  HasForward
    (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
    (EncoderDecoderTransformerGenerationInput decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask)
    generatorDevice
    (EncoderDecoderTransformerOutput headOutput encoderOutput)
    generatorOutputDevice
  where
  forward
    GEncoderDecoderTransformer {edtInputEmbedDim, edtSharedEmbedding, edtEmbedScaling, edtDecoder, edtHead}
    EncoderDecoderTransformerGenerationInput {..} =
      let -- Scaling factor: square root of the (demoted) input embedding size.
          scaling :: Double
          scaling = sqrt . fromIntegral . forgetIsChecked . dimSize . fromSing $ edtInputEmbedDim
       in runIxStateT $
            ireturn edtGenerationDecoderInput
              -- Embed the decoder input with the shared embedding.
              >>>= IxStateT . forward edtSharedEmbedding
              -- Optionally scale the embeddings; this step leaves the generator untouched.
              >>>= ilift
                . ( \case
                      EncoderDecoderTransformerWithoutEmbedScaling -> pure
                      EncoderDecoderTransformerWithEmbedScaling -> flip mulScalar scaling
                  )
                  edtEmbedScaling
              -- Run the decoder on the embeddings, the given encoder output,
              -- the decoder positions, and both attention masks.
              >>>= ( \decoderInput' ->
                       IxStateT $
                         forward
                           edtDecoder
                           ( decoderInput',
                             edtGenerationEncoderOutput,
                             edtGenerationDecoderPos,
                             edtGenerationDecoderAttentionMask,
                             edtGenerationCrossAttentionMask
                           )
                   )
              -- Apply the head to the decoder output.
              >>>= IxStateT . forward edtHead
              -- Return the head's output alongside the unchanged encoder output,
              -- so callers can feed it back in on the next generation step.
              >>>= \headOutput ->
                ireturn (EncoderDecoderTransformerOutput headOutput edtGenerationEncoderOutput)

-- | 'HasForward' instance for simplified encoder-decoder models.
--
-- This instance shifts decoder inputs by one token to the right by adding
-- a model-specific sequence initialization token at the beginning.
instance
  ( HasForward
      (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
      (SimplifiedEncoderDecoderTransformerInput' input)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput' encoderOutput inputPaddingMask)
      generatorDevice',
    HasForward
      (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
      (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask)
      generatorDevice'
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
      generatorOutputDevice
  ) =>
  HasForward
    (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
    (SimplifiedEncoderDecoderTransformerInput input decoderInput)
    generatorDevice
    (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
    generatorOutputDevice
  where
  -- The full forward pass is the composition of the two 'HasForward'
  -- instances required by the context:
  --
  --   1. the encoder-only pass on 'SimplifiedEncoderDecoderTransformerInput'',
  --      which yields the encoder output and the input padding mask, and
  --   2. the generation pass on
  --      'SimplifiedEncoderDecoderTransformerGenerationInput', which consumes
  --      those two values together with the caller's decoder input.
  --
  -- The generator is threaded through both passes by 'IxStateT'; its device
  -- index goes from @generatorDevice@ via @generatorDevice'@ to
  -- @generatorOutputDevice@.
  forward model SimplifiedEncoderDecoderTransformerInput {..} =
    runIxStateT $
      -- Pass 1: run only the encoder side of the model on the input.
      IxStateT (forward model SimplifiedEncoderDecoderTransformerInput' {sedtInput' = sedtInput})
        -- Repackage the encoder results with the decoder input.
        >>>= ( \SimplifiedEncoderDecoderTransformerOutput' {..} ->
                 ireturn $
                   SimplifiedEncoderDecoderTransformerGenerationInput
                     { sedtGenerationDecoderInput = sedtDecoderInput,
                       sedtGenerationEncoderOutput = sedtEncoderOutput',
                       sedtGenerationInputPaddingMask = sedtInputPaddingMask'
                     }
             )
        -- Pass 2: run the decoder side via the generation instance.
        >>>= IxStateT . forward model

-- | 'HasForward' instance for simplified encoder-decoder models.
-- Use this instance to run only the encoder part of the model.
--
-- The padding mask, the attention mask, and the positions are derived from
-- the input by the @mkPaddingMask@, @mkAttentionMask@, and @mkPos@ components
-- of the model. The output carries the encoder output together with the
-- input padding mask that was computed along the way.
instance
  ( HasForward
      mkPaddingMask
      input
      generatorDevice
      inputPaddingMask
      generatorDevice,
    HasForward
      mkAttentionMask
      inputPaddingMask
      generatorDevice
      attentionMask
      generatorDevice,
    HasForward
      mkPos
      input
      generatorDevice
      pos
      generatorDevice,
    HasForward
      model
      (EncoderDecoderTransformerInput' input pos attentionMask)
      generatorDevice
      (EncoderDecoderTransformerOutput' encoderOutput)
      generatorOutputDevice
  ) =>
  HasForward
    (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
    (SimplifiedEncoderDecoderTransformerInput' input)
    generatorDevice
    (SimplifiedEncoderDecoderTransformerOutput' encoderOutput inputPaddingMask)
    generatorOutputDevice
  where
  forward GSimplifiedEncoderDecoderTransformer {..} SimplifiedEncoderDecoderTransformerInput' {..} =
    runIxStateT $
      -- Step 1: derive the padding mask from the raw input.
      IxStateT (forward sedtMkPaddingMask sedtInput')
        >>>= ( \inputPaddingMask ->
                 let -- Positions are computed from the raw input.
                     pos = IxStateT . forward sedtMkPos $ sedtInput'
                     -- The attention mask is computed from the padding mask.
                     attentionMask = IxStateT . forward sedtMkAttentionMask $ inputPaddingMask
                  in -- Step 2: assemble the input of the wrapped model.
                     EncoderDecoderTransformerInput'
                       sedtInput'
                       <<$>> pos
                       <<*>> attentionMask
                       -- Step 3: run the wrapped model (encoder only).
                       >>>= IxStateT . forward sedtModel
                       -- Step 4: return the encoder output together with
                       -- the padding mask from step 1.
                       >>>= ( \EncoderDecoderTransformerOutput' {..} ->
                                ireturn $
                                  SimplifiedEncoderDecoderTransformerOutput'
                                    { sedtEncoderOutput' = edtEncoderOutput',
                                      sedtInputPaddingMask' = inputPaddingMask
                                    }
                            )
             )

-- | 'HasForward' instance for simplified encoder-decoder models.
-- Use this instance for sequence generation once the encoder's output is available.
--
-- This instance shifts decoder inputs by one token to the right by adding
-- a model-specific sequence initialization token at the beginning.
instance
  ( HasForward
      mkPaddingMask
      decoderInput
      generatorDevice
      decoderInputPaddingMask
      generatorDevice,
    HasForward
      mkCrossAttentionMask
      (rightShiftedDecoderInput, inputPaddingMask)
      generatorDevice
      crossAttentionMask
      generatorDevice,
    HasForward
      mkDecoderAttentionMask
      rightShiftedDecoderInputPaddingMask
      generatorDevice
      decoderAttentionMask
      generatorDevice,
    HasForward
      (ShiftRight Int)
      decoderInput
      generatorDevice
      rightShiftedDecoderInput
      generatorDevice,
    HasForward
      (ShiftRight Int)
      decoderInputPaddingMask
      generatorDevice
      rightShiftedDecoderInputPaddingMask
      generatorDevice,
    HasForward
      mkDecoderPos
      rightShiftedDecoderInput
      generatorDevice
      decoderPos
      generatorDevice,
    HasForward
      model
      (EncoderDecoderTransformerGenerationInput rightShiftedDecoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask)
      generatorDevice
      (EncoderDecoderTransformerOutput decoderOutput encoderOutput)
      generatorOutputDevice
  ) =>
  HasForward
    (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
    (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask)
    generatorDevice
    (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
    generatorOutputDevice
  where
  -- Outline of the pass:
  --
  --   1. shift the decoder input with 'sedtDecoderInputShift', and compute the
  --      decoder padding mask from the unshifted decoder input and shift it
  --      with 'sedtPaddingMaskShift';
  --   2. derive decoder positions, the cross-attention mask, and the decoder
  --      attention mask from the shifted values;
  --   3. run the wrapped model on the assembled generation input;
  --   4. return the model outputs together with the original (unshifted)
  --      decoder input and the input padding mask that were passed in.
  forward GSimplifiedEncoderDecoderTransformer {..} SimplifiedEncoderDecoderTransformerGenerationInput {..} =
    runIxStateT $
      ( let rightShiftedDecoderInput =
              IxStateT . forward sedtDecoderInputShift $ sedtGenerationDecoderInput
            -- The padding mask is computed first and shifted afterwards.
            rightShiftedDecoderInputPaddingMask =
              ireturn sedtGenerationDecoderInput
                >>>= IxStateT . forward sedtMkPaddingMask
                >>>= IxStateT . forward sedtPaddingMaskShift
         in (,)
              <<$>> rightShiftedDecoderInput
              <<*>> rightShiftedDecoderInputPaddingMask
      )
        >>>= ( \(rightShiftedDecoderInput, rightShiftedDecoderInputPaddingMask) ->
                 let decoderPos =
                       IxStateT . forward sedtMkDecoderPos $ rightShiftedDecoderInput
                     -- Cross-attention masking pairs the shifted decoder input
                     -- with the padding mask of the encoder input.
                     crossAttentionMask =
                       IxStateT . forward sedtMkCrossAttentionMask $
                         (rightShiftedDecoderInput, sedtGenerationInputPaddingMask)
                     decoderAttentionMask =
                       IxStateT . forward sedtMkDecoderAttentionMask $
                         rightShiftedDecoderInputPaddingMask
                  in ( EncoderDecoderTransformerGenerationInput
                         rightShiftedDecoderInput
                         sedtGenerationEncoderOutput
                         <<$>> decoderPos
                         <<*>> decoderAttentionMask
                         <<*>> crossAttentionMask
                     )
                       >>>= IxStateT . forward sedtModel
                       >>>= ( \(EncoderDecoderTransformerOutput decoderOutput encoderOutput) ->
                                ireturn $
                                  SimplifiedEncoderDecoderTransformerOutput
                                    decoderOutput
                                    encoderOutput
                                    sedtGenerationDecoderInput
                                    sedtGenerationInputPaddingMask
                            )
             )

instance
  ( HasForward
      (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
      (SimplifiedEncoderDecoderTransformerInput input decoderInput)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
      generatorOutputDevice,
    decoderInput
      ~ Tensor
          targetGradient
          targetLayout
          targetDevice
          targetDataType
          (IndexDims ('Indices '[ 'SliceAll, 'SliceUpTo ('NegativeIndex 1)]) targetShape),
    decoderOutput
      ~ Tensor
          doGradient
          doLayout
          doDevice
          doDataType
          doShape,
    logProbsShape ~ SoftmaxF ('SelectDim ('ByIndex 2)) doShape,
    Catch logProbsShape,
    unsqueezedTargetShape ~ UnsqueezeF ('SelectDim ('ByIndex 2)) targetShape,
    Catch unsqueezedTargetShape,
    gatheredLogProbsShape ~ GatherDimF ('SelectDim ('ByIndex 2)) unsqueezedTargetShape logProbsShape,
    Catch gatheredLogProbsShape,
    Catch (targetDataType <+> 'DataType 'Int64),
    logLikelihoodShape ~ SqueezeDimF ('SelectDim ('ByIndex 2)) gatheredLogProbsShape,
    Catch logLikelihoodShape,
    MeanAllCheckF logLikelihoodShape,
    loss
      ~ Tensor
          (targetGradient <|> doGradient)
          (targetLayout <+> doLayout)
          (targetDevice <+> doDevice)
          doDataType
          ('Shape '[]),
    generatorOutputDevice ~ generatorDevice
  ) =>
  HasForward
    (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
    ( SimplifiedEncoderDecoderTransformerTrainingInput
        input
        (Tensor targetGradient targetLayout targetDevice targetDataType targetShape)
    )
    generatorDevice
    (SimplifiedEncoderDecoderTransformerTrainingOutput loss)
    generatorOutputDevice
  where
  -- Training forward pass: run the wrapped encoder-decoder model and reduce
  -- its decoder output to a scalar loss.
  --
  -- Arguments (the generator is threaded implicitly through 'IxStateT'):
  --
  --   * @eot@ -- the simplified encoder-decoder transformer to run.
  --   * the training input record, whose fields @sedtTrainingInput@ (model
  --     input) and @sedtTarget@ (target tensor) are brought into scope by
  --     the record wildcard.
  --
  -- Result: the loss wrapped in
  -- 'SimplifiedEncoderDecoderTransformerTrainingOutput', paired with the
  -- output generator.
  --
  -- Steps:
  --
  --   1. Index the target with @[SliceAll, SliceUpTo (NegativeIndex 1)]@ to
  --      build the decoder input (presumably: every row, all positions but
  --      the last along the second dimension -- confirm against the
  --      indexing semantics of '!').
  --   2. Run the non-training 'forward' of @eot@ on the input together with
  --      that decoder input, and keep only its decoder output (the logits).
  --   3. Take 'logSoftmax' over dimension 2 of the logits, gather the entries
  --      selected by the (unsqueezed) full target along that same dimension,
  --      squeeze the dimension away again, and return the negated 'meanAll'.
  --
  -- NOTE: no explicit type signature is given here on purpose; the module
  -- does not enable @InstanceSigs@.
  forward eot SimplifiedEncoderDecoderTransformerTrainingInput {..} =
    runIxStateT $
      ireturn sedtTarget
        -- Indexing can fail, hence it runs in the underlying monad via 'ilift'.
        >>>= ilift . (! SIndices (SSliceAll :|: SSliceUpTo (SNegativeIndex @1) :|: SNil))
        -- Delegate to the plain (non-training) forward pass of the model;
        -- wrapping it in 'IxStateT' threads the generator through it.
        >>>= (\sedtDecoderInput -> IxStateT . forward eot $ SimplifiedEncoderDecoderTransformerInput {sedtInput = sedtTrainingInput, sedtDecoderInput})
        >>>= ireturn . sedtDecoderOutput
        >>>= ilift
          . ( \logits -> do
                logProbs <- logSoftmax (SSelectDim $ SByIndex @2) logits
                -- The target needs a trailing dimension so it can be used
                -- as an index tensor for the gather along dimension 2.
                target' <- sUnsqueeze (SSelectDim $ SByIndex @2) sedtTarget
                gatheredLogProbs <- sGatherDim (SSelectDim $ SByIndex @2) target' logProbs
                logLikelihood <- sSqueezeDim (SSelectDim $ SByIndex @2) gatheredLogProbs
                -- Negative mean log-likelihood over all remaining entries.
                pure . negate $ meanAll logLikelihood
            )
        >>>= ireturn . SimplifiedEncoderDecoderTransformerTrainingOutput