{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}

module Torch.GraduallyTyped.NN.Transformer.GEncoderOnly where

import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import Data.Singletons (SingKind (fromSing))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType)
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Sparse (Embedding (..), EmbeddingSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GLMHead (GLMHeadF, lmHeadSpec)
import Torch.GraduallyTyped.NN.Transformer.GTransformer (TransformerEncoderF, transformerEncoderSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead (..), STransformerStyle (..), TransformerHead (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked)
import Torch.GraduallyTyped.Prelude.Maybe (SMaybe (SNothing))
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient)
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Size (..))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add, mulScalar)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Prelude hiding (head)

-- | Flag that records whether the encoder-only transformer model scales its
-- input embedding.
--
-- The constructors are ordered so that 'EncoderOnlyTransformerWithEmbedScaling'
-- compares less than 'EncoderOnlyTransformerWithoutEmbedScaling' under the
-- derived 'Ord' instance.
data EncoderOnlyTransformerHasEmbedScaling
  = -- | The embedding is scaled.
    EncoderOnlyTransformerWithEmbedScaling
  | -- | The embedding is used as is.
    EncoderOnlyTransformerWithoutEmbedScaling
  deriving (Eq, Ord, Show, Generic)

-- | The specification of the embed-scaling flag is the flag type itself.
type instance ModelSpec EncoderOnlyTransformerHasEmbedScaling = EncoderOnlyTransformerHasEmbedScaling

-- | Initializing the embed-scaling flag is trivial:
-- the specification is returned as the model and
-- the generator is handed back unchanged.
instance HasInitialize EncoderOnlyTransformerHasEmbedScaling generatorDevice EncoderOnlyTransformerHasEmbedScaling generatorDevice where
  initialize hasEmbedScaling g = pure (hasEmbedScaling, g)

-- | The embed-scaling flag is not stored in the state dictionary.
--
-- 'fromStateDict' returns the given specification and ignores the key;
-- 'toStateDict' ignores both the key and the value and does nothing.
instance HasStateDict EncoderOnlyTransformerHasEmbedScaling where
  fromStateDict hasEmbedScaling _ = pure hasEmbedScaling
  toStateDict _ _ = pure ()

-- | Generic encoder-only transformer model.
-- This is a transformer model that only encodes the input, e.g. BERT.
--
-- - @inputEmbedDim@: the dimension of the input embedding.
-- - @encoder@: a transformer encoder.
-- - @encoderEmbedding@: an embedding layer for the input.
-- - @encoderTypeEmbedding@: an embedding layer for the type of the input.
-- - @head@: a head layer for the output.
--
-- Besides these components, the record stores the singleton for
-- @inputEmbedDim@ and an 'EncoderOnlyTransformerHasEmbedScaling' flag.
data
  GEncoderOnlyTransformer
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (encoder :: Type)
    (encoderEmbedding :: Type)
    (encoderTypeEmbedding :: Type)
    (head :: Type)
  where
  GEncoderOnlyTransformer ::
    forall inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head.
    { -- | input embedding dim for scaling
      eotInputEmbedDim :: SDim inputEmbedDim,
      -- | encoder
      eotEncoder :: encoder,
      -- | encoder embedding
      eotEmbedding :: encoderEmbedding,
      -- | encoder type embedding
      eotTypeEmbedding :: encoderTypeEmbedding,
      -- | encoder head
      eotHead :: head,
      -- | encoder embedding scaling
      eotEmbedScaling :: EncoderOnlyTransformerHasEmbedScaling
    } ->
    GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
  deriving stock (Show, Generic)

-- | The specification of a 'GEncoderOnlyTransformer' is again a
-- 'GEncoderOnlyTransformer' in which the encoder, the two embeddings, and the
-- head are each replaced by their own 'ModelSpec'.
-- The @inputEmbedDim@ index is carried over unchanged.
type instance
  ModelSpec (GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head) =
    GEncoderOnlyTransformer inputEmbedDim (ModelSpec encoder) (ModelSpec encoderEmbedding) (ModelSpec encoderTypeEmbedding) (ModelSpec head)

-- | Computes the concrete 'GEncoderOnlyTransformer' type for the given
-- transformer style, head, number of layers, gradient, device, data type,
-- dimensions, and dropout setting.
--
-- The components are assembled as follows:
--
-- - the encoder is a 'NamedModel' wrapping 'TransformerEncoderF',
-- - the embedding is given by 'EOTEmbeddingF',
-- - the type embedding is given by 'EOTTypeEmbeddingF', and
-- - the head is given by 'EOTHeadF'.
type family
  GEncoderOnlyTransformerF
    (style :: TransformerStyle)
    (transformerHead :: TransformerHead)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (typeVocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  GEncoderOnlyTransformerF style transformerHead numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim typeVocabDim hasDropout =
    GEncoderOnlyTransformer
      inputEmbedDim
      (NamedModel (TransformerEncoderF style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout))
      (EOTEmbeddingF style gradient device dataType inputEmbedDim vocabDim)
      (EOTTypeEmbeddingF style gradient device dataType inputEmbedDim typeVocabDim)
      (EOTHeadF style transformerHead gradient device dataType inputEmbedDim vocabDim)

-- | Specifies the embedding layer of the encoder-only transformer model.
--
-- The style argument is ignored: every style maps to a 'NamedModel' wrapping
-- an 'Embedding' with a dense layout that is parameterised by @vocabDim@ and
-- @inputEmbedDim@.
--
-- NOTE(review): the trailing @'Nothing@ is presumably the padding index of
-- 'Embedding' -- confirm against "Torch.GraduallyTyped.NN.Sparse".
type family
  EOTEmbeddingF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  EOTEmbeddingF _ gradient device dataType inputEmbedDim vocabDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType vocabDim inputEmbedDim 'Nothing)

-- | Specifies the type embedding layer of the encoder-only transformer model.
--
-- There are equations only for the @'BERT@ and @'RoBERTa@ styles.
-- For @'BERT@, the result is a 'NamedModel' wrapping an 'Embedding' with a
-- dense layout that is parameterised by @typeVocabDim@ and @inputEmbedDim@.
-- The @'RoBERTa@ equation defers to the @'BERT@ one, so both styles share the
-- same type.
type family
  EOTTypeEmbeddingF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (typeVocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  EOTTypeEmbeddingF 'BERT gradient device dataType inputEmbedDim typeVocabDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType typeVocabDim inputEmbedDim 'Nothing)
  EOTTypeEmbeddingF 'RoBERTa gradient device dataType inputEmbedDim typeVocabDim =
    EOTTypeEmbeddingF 'BERT gradient device dataType inputEmbedDim typeVocabDim

-- | Specifies the head layer of the encoder-only transformer model.
--
-- - For @'WithoutHead@, the head is the unit type, regardless of all other
--   arguments.
-- - For @'WithLMHead@, the head is a 'NamedModel' wrapping the language
--   modelling head given by 'GLMHeadF' for the same style.
type family
  EOTHeadF
    (style :: TransformerStyle)
    (transformerHead :: TransformerHead)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  EOTHeadF _ 'WithoutHead _ _ _ _ _ =
    ()
  EOTHeadF style 'WithLMHead gradient device dataType inputEmbedDim vocabDim =
    NamedModel (GLMHeadF style gradient device dataType inputEmbedDim vocabDim)

-- | Specifies the parameters of an encoder-only transformer model.
--
-- - @style@: the style of the encoder-only transformer model, e.g. 'SBERT', 'SRoBERTa', etc.
-- - @transformerHead@: the head of the encoder-only transformer model.
-- - @numLayers@: the number of layers of the encoder-only transformer model.
-- - @gradient@: whether to compute the gradient of the model parameters.
-- - @device@: the computational device on which the model is allocated.
-- - @dataType@: the data type of the model parameters.
-- - @headDim@: the dimension of all transformer heads in the encoder-only transformer model.
-- - @headEmbedDim@: the dimension of the transformer head embeddings.
-- - @embedDim@: the dimension of the transformer embeddings.
-- - @inputEmbedDim@: the dimension of the input embeddings.
-- - @ffnDim@: the dimension of the feed-forward network.
-- - @posEncDim@: the dimension of the positional embeddings.
-- - @vocabDim@: the dimension of the vocabulary.
-- - @typeVocabDim@: the dimension of the type vocabulary.
-- - @dropoutP@: the dropout rate.
-- - @eps@: the epsilon value for numerical stability of the layer normalization.
encoderOnlyTransformerSpec ::
  forall style transformerHead numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim typeVocabDim hasDropout.
  STransformerStyle style ->
  STransformerHead transformerHead ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim inputEmbedDim ->
  SDim ffnDim ->
  SDim posEncDim ->
  SDim vocabDim ->
  SDim typeVocabDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (GEncoderOnlyTransformerF style transformerHead numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim typeVocabDim hasDropout)
encoderOnlyTransformerSpec :: forall (style :: TransformerStyle)
       (transformerHead :: TransformerHead) (numLayers :: Nat)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
       (ffnDim :: Dim (Name Symbol) (Size Nat))
       (posEncDim :: Dim (Name Symbol) (Size Nat))
       (vocabDim :: Dim (Name Symbol) (Size Nat))
       (typeVocabDim :: Dim (Name Symbol) (Size Nat))
       (hasDropout :: HasDropout).
STransformerStyle style
-> STransformerHead transformerHead
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SDim vocabDim
-> SDim typeVocabDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (GEncoderOnlyTransformerF
        style
        transformerHead
        numLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        inputEmbedDim
        ffnDim
        posEncDim
        vocabDim
        typeVocabDim
        hasDropout)
encoderOnlyTransformerSpec STransformerStyle style
style STransformerHead transformerHead
transformerHead SNat numLayers
numLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SDim vocabDim
vocabDim SDim typeVocabDim
typeVocabDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps =
  let encoderSpec :: STransformerStyle style
-> NamedModel
     (GTransformer
        (ModelSpec
           (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
        (ModelSpec
           (TERelPosEncF style gradient device dataType headDim posEncDim))
        (ModelSpec
           (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TEInitialDropoutF style hasDropout))
        (NamedModel
           (GTransformerStack
              (VectorSpec
                 numLayers
                 (GTransformerBlock
                    (NamedModel
                       (GSelfAttention
                          (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                          (NamedModel
                             (GMultiHeadAttention
                                headDim
                                headEmbedDim
                                embedDim
                                (QInProjF style gradient device dataType inputEmbedDim embedDim)
                                (KInProjF style gradient device dataType inputEmbedDim embedDim)
                                (VInProjF style gradient device dataType inputEmbedDim embedDim)
                                (OutProjF style gradient device dataType embedDim inputEmbedDim)
                                (DropoutF style hasDropout)))
                          (SADropoutF style hasDropout)
                          (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                    ()
                    (NamedModel
                       (GTransformerFeedForwardNetwork
                          (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                          (FFNInputTransformationF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNActivationF style)
                          (FFNActivationDropoutF style hasDropout)
                          (FFNOutputProjectionF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNOutputDropoutF style hasDropout)
                          (FFNOutputLayerNormF
                             style gradient device dataType inputEmbedDim)))))))
        (ModelSpec
           (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TEFinalDropoutF style hasDropout)))
encoderSpec STransformerStyle style
ST5 = forall a. HasCallStack => a
undefined
      encoderSpec STransformerStyle style
SByT5 = forall a. HasCallStack => a
undefined
      encoderSpec STransformerStyle style
SBART = forall a. HasCallStack => a
undefined
      encoderSpec STransformerStyle style
SMBART = forall a. HasCallStack => a
undefined
      encoderSpec STransformerStyle style
SPegasus = forall a. HasCallStack => a
undefined
      encoderSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"bert." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
     (ModelSpec
        (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TERelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 ()
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'BERT
SBERT
      encoderSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"roberta." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
     (ModelSpec
        (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TERelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 ()
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'RoBERTa
SRoBERTa
      encoderSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      embeddingSpec :: STransformerStyle style
-> NamedModel
     (EmbeddingSpec
        gradient
        ('Layout 'Dense)
        device
        dataType
        vocabDim
        inputEmbedDim
        'Nothing)
embeddingSpec STransformerStyle style
ST5 = forall a. HasCallStack => a
undefined
      embeddingSpec STransformerStyle style
SByT5 = forall a. HasCallStack => a
undefined
      embeddingSpec STransformerStyle style
SBART = forall a. HasCallStack => a
undefined
      embeddingSpec STransformerStyle style
SMBART = forall a. HasCallStack => a
undefined
      embeddingSpec STransformerStyle style
SPegasus = forall a. HasCallStack => a
undefined
      embeddingSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"bert.embeddings.word_embeddings." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  vocabDim
  inputEmbedDim
  'Nothing
embeddingSpec'
      embeddingSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"roberta.embeddings.word_embeddings." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  vocabDim
  inputEmbedDim
  'Nothing
embeddingSpec'
      embeddingSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      typeEmbeddingSpec :: STransformerStyle style
-> ModelSpec
     (EOTTypeEmbeddingF
        style gradient device dataType inputEmbedDim typeVocabDim)
typeEmbeddingSpec STransformerStyle style
ST5 = forall a. HasCallStack => a
undefined
      typeEmbeddingSpec STransformerStyle style
SByT5 = forall a. HasCallStack => a
undefined
      typeEmbeddingSpec STransformerStyle style
SBART = forall a. HasCallStack => a
undefined
      typeEmbeddingSpec STransformerStyle style
SMBART = forall a. HasCallStack => a
undefined
      typeEmbeddingSpec STransformerStyle style
SPegasus = forall a. HasCallStack => a
undefined
      typeEmbeddingSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"bert.embeddings.token_type_embeddings." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  typeVocabDim
  inputEmbedDim
  'Nothing
typeEmbeddingSpec'
      typeEmbeddingSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"roberta.embeddings.token_type_embeddings." EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  typeVocabDim
  inputEmbedDim
  'Nothing
typeEmbeddingSpec'
      typeEmbeddingSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      headSpec :: STransformerStyle style
-> STransformerHead transformerHead
-> ModelSpec
     (EOTHeadF
        style
        transformerHead
        gradient
        device
        dataType
        inputEmbedDim
        vocabDim)
headSpec STransformerStyle style
ST5 STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
      headSpec STransformerStyle style
SByT5 STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
      headSpec STransformerStyle style
SBART STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
      headSpec STransformerStyle style
SMBART STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
      headSpec STransformerStyle style
SPegasus STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
      headSpec STransformerStyle style
SBERT STransformerHead transformerHead
SWithoutHead = ()
      headSpec STransformerStyle style
SBERT STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"cls.predictions." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
     inputEmbedDim
     (ModelSpec
        (LMHeadDenseF style gradient device dataType inputEmbedDim))
     (ModelSpec (LMHeadActivationF style))
     (ModelSpec
        (LMHeadLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec
        (LMHeadDecoderF
           style gradient device dataType inputEmbedDim vocabDim))
     (ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'BERT
SBERT
      headSpec STransformerStyle style
SRoBERTa STransformerHead transformerHead
SWithoutHead = ()
      headSpec STransformerStyle style
SRoBERTa STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"lm_head." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
     inputEmbedDim
     (ModelSpec
        (LMHeadDenseF style gradient device dataType inputEmbedDim))
     (ModelSpec (LMHeadActivationF style))
     (ModelSpec
        (LMHeadLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec
        (LMHeadDecoderF
           style gradient device dataType inputEmbedDim vocabDim))
     (ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'RoBERTa
SRoBERTa
      headSpec STransformerStyle style
SGPT2 STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
      embedScalingSpec :: STransformerStyle style -> EncoderOnlyTransformerHasEmbedScaling
      embedScalingSpec :: STransformerStyle style -> EncoderOnlyTransformerHasEmbedScaling
embedScalingSpec STransformerStyle style
ST5 = forall a. HasCallStack => a
undefined
      embedScalingSpec STransformerStyle style
SByT5 = forall a. HasCallStack => a
undefined
      embedScalingSpec STransformerStyle style
SBART = forall a. HasCallStack => a
undefined
      embedScalingSpec STransformerStyle style
SMBART = forall a. HasCallStack => a
undefined
      embedScalingSpec STransformerStyle style
SPegasus = forall a. HasCallStack => a
undefined
      embedScalingSpec STransformerStyle style
SBERT = EncoderOnlyTransformerHasEmbedScaling
EncoderOnlyTransformerWithoutEmbedScaling
      embedScalingSpec STransformerStyle style
SRoBERTa = EncoderOnlyTransformerHasEmbedScaling
EncoderOnlyTransformerWithoutEmbedScaling
      embedScalingSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
   in forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
       encoderEmbedding encoderTypeEmbedding head.
SDim inputEmbedDim
-> encoder
-> encoderEmbedding
-> encoderTypeEmbedding
-> head
-> EncoderOnlyTransformerHasEmbedScaling
-> GEncoderOnlyTransformer
     inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
GEncoderOnlyTransformer
        SDim inputEmbedDim
inputEmbedDim
        (STransformerStyle style
-> NamedModel
     (GTransformer
        (ModelSpec
           (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
        (ModelSpec
           (TERelPosEncF style gradient device dataType headDim posEncDim))
        (ModelSpec
           (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TEInitialDropoutF style hasDropout))
        (NamedModel
           (GTransformerStack
              (VectorSpec
                 numLayers
                 (GTransformerBlock
                    (NamedModel
                       (GSelfAttention
                          (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                          (NamedModel
                             (GMultiHeadAttention
                                headDim
                                headEmbedDim
                                embedDim
                                (QInProjF style gradient device dataType inputEmbedDim embedDim)
                                (KInProjF style gradient device dataType inputEmbedDim embedDim)
                                (VInProjF style gradient device dataType inputEmbedDim embedDim)
                                (OutProjF style gradient device dataType embedDim inputEmbedDim)
                                (DropoutF style hasDropout)))
                          (SADropoutF style hasDropout)
                          (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                    ()
                    (NamedModel
                       (GTransformerFeedForwardNetwork
                          (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                          (FFNInputTransformationF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNActivationF style)
                          (FFNActivationDropoutF style hasDropout)
                          (FFNOutputProjectionF
                             style gradient device dataType inputEmbedDim ffnDim)
                          (FFNOutputDropoutF style hasDropout)
                          (FFNOutputLayerNormF
                             style gradient device dataType inputEmbedDim)))))))
        (ModelSpec
           (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
        (ModelSpec (TEFinalDropoutF style hasDropout)))
encoderSpec STransformerStyle style
style)
        (STransformerStyle style
-> NamedModel
     (EmbeddingSpec
        gradient
        ('Layout 'Dense)
        device
        dataType
        vocabDim
        inputEmbedDim
        'Nothing)
embeddingSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (EOTTypeEmbeddingF
        style gradient device dataType inputEmbedDim typeVocabDim)
typeEmbeddingSpec STransformerStyle style
style)
        (STransformerStyle style
-> STransformerHead transformerHead
-> ModelSpec
     (EOTHeadF
        style
        transformerHead
        gradient
        device
        dataType
        inputEmbedDim
        vocabDim)
headSpec STransformerStyle style
style STransformerHead transformerHead
transformerHead)
        (STransformerStyle style -> EncoderOnlyTransformerHasEmbedScaling
embedScalingSpec STransformerStyle style
style)
  where
    encoderSpec' :: _
    encoderSpec' :: STransformerStyle style
-> GTransformer
     (ModelSpec
        (TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
     (ModelSpec
        (TERelPosEncF style gradient device dataType headDim posEncDim))
     (ModelSpec
        (TEInitialLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEInitialDropoutF style hasDropout))
     (NamedModel
        (GTransformerStack
           (VectorSpec
              numLayers
              (GTransformerBlock
                 (NamedModel
                    (GSelfAttention
                       (SAInitialLayerNormF style gradient device dataType inputEmbedDim)
                       (NamedModel
                          (GMultiHeadAttention
                             headDim
                             headEmbedDim
                             embedDim
                             (QInProjF style gradient device dataType inputEmbedDim embedDim)
                             (KInProjF style gradient device dataType inputEmbedDim embedDim)
                             (VInProjF style gradient device dataType inputEmbedDim embedDim)
                             (OutProjF style gradient device dataType embedDim inputEmbedDim)
                             (DropoutF style hasDropout)))
                       (SADropoutF style hasDropout)
                       (SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
                 ()
                 (NamedModel
                    (GTransformerFeedForwardNetwork
                       (FFNInputLayerNormF style gradient device dataType inputEmbedDim)
                       (FFNInputTransformationF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNActivationF style)
                       (FFNActivationDropoutF style hasDropout)
                       (FFNOutputProjectionF
                          style gradient device dataType inputEmbedDim ffnDim)
                       (FFNOutputDropoutF style hasDropout)
                       (FFNOutputLayerNormF
                          style gradient device dataType inputEmbedDim)))))))
     (ModelSpec
        (TEFinalLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle style
style' = forall (style :: TransformerStyle) (numLayers :: Nat)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
       (ffnDim :: Dim (Name Symbol) (Size Nat))
       (posEncDim :: Dim (Name Symbol) (Size Nat))
       (hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (TransformerEncoderF
        style
        numLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        inputEmbedDim
        ffnDim
        posEncDim
        hasDropout)
transformerEncoderSpec STransformerStyle style
style' SNat numLayers
numLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps
    embeddingSpec' :: EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  vocabDim
  inputEmbedDim
  'Nothing
embeddingSpec' = forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim vocabDim
vocabDim SDim inputEmbedDim
inputEmbedDim forall a. SMaybe 'Nothing
SNothing
    typeEmbeddingSpec' :: EmbeddingSpec
  gradient
  ('Layout 'Dense)
  device
  dataType
  typeVocabDim
  inputEmbedDim
  'Nothing
typeEmbeddingSpec' = forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (embedNumDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (paddingIdx :: Maybe Nat).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
     gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim typeVocabDim
typeVocabDim SDim inputEmbedDim
inputEmbedDim forall a. SMaybe 'Nothing
SNothing
    headSpec' :: _
    headSpec' :: STransformerStyle style
-> GLMHead
     inputEmbedDim
     (ModelSpec
        (LMHeadDenseF style gradient device dataType inputEmbedDim))
     (ModelSpec (LMHeadActivationF style))
     (ModelSpec
        (LMHeadLayerNormF style gradient device dataType inputEmbedDim))
     (ModelSpec
        (LMHeadDecoderF
           style gradient device dataType inputEmbedDim vocabDim))
     (ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle style
style' = forall (style :: TransformerStyle)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
       (vocabDim :: Dim (Name Symbol) (Size Nat)).
STransformerStyle style
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputEmbedDim
-> SDim vocabDim
-> Double
-> ModelSpec
     (GLMHeadF style gradient device dataType inputEmbedDim vocabDim)
lmHeadSpec STransformerStyle style
style' SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim inputEmbedDim
inputEmbedDim SDim vocabDim
vocabDim Double
eps

-- | 'HasInitialize' instance for 'GEncoderOnlyTransformer'.
--
-- No methods are defined here, so the class defaults are used. The
-- constraints chain the generator device through the four components in
-- field order: @encoder@, then @encoderEmbedding@, then
-- @encoderTypeEmbedding@, then @head@. The device left over after the head
-- is the output generator device of the whole model.
instance
  ( HasInitialize encoder genDevice encoder' genDeviceAfterEncoder,
    HasInitialize encoderEmbedding genDeviceAfterEncoder encoderEmbedding' genDeviceAfterEmbedding,
    HasInitialize encoderTypeEmbedding genDeviceAfterEmbedding encoderTypeEmbedding' genDeviceAfterTypeEmbedding,
    HasInitialize head genDeviceAfterTypeEmbedding head' genOutputDevice
  ) =>
  HasInitialize
    (GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head)
    genDevice
    (GEncoderOnlyTransformer inputEmbedDim encoder' encoderEmbedding' encoderTypeEmbedding' head')
    genOutputDevice

-- | 'HasStateDict' instance for 'GEncoderOnlyTransformer'.
--
-- Requires a 'HasStateDict' instance for each of the four components; no
-- methods are overridden, so the class defaults apply.
instance
  ( HasStateDict encoder,
    HasStateDict encoderEmbedding,
    HasStateDict encoderTypeEmbedding,
    HasStateDict head
  ) =>
  HasStateDict
    (GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head)

-- | Wrapper that pairs an encoder-only model with three helper components
-- used to make its input positions, padding mask, and attention mask.
--
-- All four type parameters are unconstrained 'Type's, so the same
-- constructor can hold either the components themselves or their
-- specifications.
data
  GSimplifiedEncoderOnlyTransformer
    (model :: Type)
    (mkPos :: Type)
    (mkPaddingMask :: Type)
    (mkAttentionMask :: Type)
  where
  GSimplifiedEncoderOnlyTransformer ::
    forall model mkPos mkPaddingMask mkAttentionMask.
    { -- | encoder-only model
      seotModel :: model,
      -- | make input positions
      seotMkPos :: mkPos,
      -- | make padding mask
      seotMkPaddingMask :: mkPaddingMask,
      -- | make attention mask
      seotMkAttentionMask :: mkAttentionMask
    } ->
    GSimplifiedEncoderOnlyTransformer model mkPos mkPaddingMask mkAttentionMask
  deriving stock (Eq, Ord, Show, Generic)

-- The specification of a simplified encoder-only transformer is the same
-- container type with 'ModelSpec' applied to each of its four components.
type instance
  ModelSpec (GSimplifiedEncoderOnlyTransformer model mkPos mkPaddingMask mkAttentionMask) =
    GSimplifiedEncoderOnlyTransformer
      (ModelSpec model)
      (ModelSpec mkPos)
      (ModelSpec mkPaddingMask)
      (ModelSpec mkAttentionMask)

-- | 'HasInitialize' instance for the simplified encoder-only transformer.
--
-- The generator device is threaded through the four components in order:
-- @model@, then @mkPos@, then @mkPaddingMask@, then @mkAttentionMask@.
-- Each component's output generator device is the next component's input
-- generator device, and the last one is the instance's output device.
--
-- No methods are defined here, so the defaults of the class are used.
instance
  ( HasInitialize model generatorDevice initializedModel deviceAfterModel,
    HasInitialize mkPos deviceAfterModel initializedMkPos deviceAfterPos,
    HasInitialize mkPaddingMask deviceAfterPos initializedMkPaddingMask deviceAfterPaddingMask,
    HasInitialize mkAttentionMask deviceAfterPaddingMask initializedMkAttentionMask generatorOutputDevice
  ) =>
  HasInitialize
    (GSimplifiedEncoderOnlyTransformer model mkPos mkPaddingMask mkAttentionMask)
    generatorDevice
    (GSimplifiedEncoderOnlyTransformer initializedModel initializedMkPos initializedMkPaddingMask initializedMkAttentionMask)
    generatorOutputDevice

-- | 'HasStateDict' instance for the simplified encoder-only transformer.
--
-- Requires a 'HasStateDict' instance for each of the four components.
-- No methods are defined here, so the defaults of the class are used.
instance
  ( HasStateDict model,
    HasStateDict mkPos,
    HasStateDict mkPaddingMask,
    HasStateDict mkAttentionMask
  ) =>
  HasStateDict (GSimplifiedEncoderOnlyTransformer model mkPos mkPaddingMask mkAttentionMask)

-- | Input data type for use with an encoder-only transformer.
--
-- A plain record with four independently typed fields:
--
-- * 'eotInput' holds the value of type @input@.
-- * 'eotInputType' holds the value of type @inputType@.
-- * 'eotPos' holds the value of type @pos@.
-- * 'eotAttentionMask' holds the value of type @attentionMask@.
--
-- The forward-pass diagram further down in this module shows @input@ going
-- through the embedding, @inputType@ going through the type embedding, and
-- @pos@ and @attentionMask@ being handed to the encoder.
--
-- 'Eq', 'Ord', 'Show' and 'Generic' are stock-derived, so comparison and
-- rendering follow the field order of the declaration.
data EncoderOnlyTransformerInput input inputType pos attentionMask where
  EncoderOnlyTransformerInput ::
    forall input inputType pos attentionMask.
    { eotInput :: input,
      eotInputType :: inputType,
      eotPos :: pos,
      eotAttentionMask :: attentionMask
    } ->
    EncoderOnlyTransformerInput input inputType pos attentionMask
  deriving stock (Eq, Ord, Show, Generic)

-- | Input data type carrying only an input and an input type.
--
-- * 'seotInput' holds the value of type @input@.
-- * 'seotInputType' holds the value of type @inputType@.
--
-- NOTE(review): presumably this is the input of
-- 'GSimplifiedEncoderOnlyTransformer', whose @mkPos@, @mkPaddingMask@ and
-- @mkAttentionMask@ components would supply what is missing compared to
-- 'EncoderOnlyTransformerInput' -- confirm against its 'HasForward' instance.
--
-- 'Eq', 'Ord', 'Show' and 'Generic' are stock-derived.
data SimplifiedEncoderOnlyTransformerInput input inputType where
  SimplifiedEncoderOnlyTransformerInput ::
    forall input inputType.
    { seotInput :: input,
      seotInputType :: inputType
    } ->
    SimplifiedEncoderOnlyTransformerInput input inputType
  deriving stock (Eq, Ord, Show, Generic)

-- | Output data type for use with an encoder-only transformer.
--
-- A single-field record: 'eotOutput' holds the value of type @output@.
-- In the 'HasForward' instance of 'GEncoderOnlyTransformer' in this module,
-- @output@ is instantiated with the output type of the head.
--
-- 'Eq', 'Ord', 'Show' and 'Generic' are stock-derived.
data EncoderOnlyTransformerOutput output where
  EncoderOnlyTransformerOutput ::
    forall output.
    { eotOutput :: output
    } ->
    EncoderOnlyTransformerOutput output
  deriving stock (Eq, Ord, Show, Generic)

-- | Output data type carrying an output together with a padding mask.
--
-- * 'seotOutput' holds the value of type @output@.
-- * 'sedtPaddingMask' holds the value of type @paddingMask@.
--
-- NOTE(review): the @sedt@ prefix of 'sedtPaddingMask' differs from the
-- @seot@ prefix of its sibling field. The name is kept as is because
-- renaming a record field would break existing callers.
--
-- 'Eq', 'Ord', 'Show' and 'Generic' are stock-derived.
data SimplifiedEncoderOnlyTransformerOutput output paddingMask where
  SimplifiedEncoderOnlyTransformerOutput ::
    forall output paddingMask.
    { seotOutput :: output,
      sedtPaddingMask :: paddingMask
    } ->
    SimplifiedEncoderOnlyTransformerOutput output paddingMask
  deriving stock (Eq, Ord, Show, Generic)

-- | 'HasForward' instance for encoder-only transformers with optional scaling and head.
--
-- @
--    ┌───────┐    ┌───────────┐  ┌─────┐  ┌───────────────┐
--    │ input │    │ inputType │  │ pos │  │ attentionMask │
--    └───┬───┘    └─────┬─────┘  └──┬──┘  └──────┬────────┘
--        │              │           │            │
--        ▼              ▼           │            │
--  eotEmbedding  eotTypeEmbedding   │            │
--        ▼              ▼           │            │
-- (embedScaling)  (embedScaling)    │            │
--        │              │           │            │
--        └────►add◄─────┘           │            │
--               │                   │            │
--               ▼                   │            │
--          eotEncoder◄──────────────┘◄───────────┘
--               ▼
--           (eotHead)
--               │
--               ▼
--          ┌────────┐
--          │ output │
--          └────────┘
-- @
instance
  ( HasForward
      encoderEmbedding
      input
      generatorDevice
      embeddingOutput
      embeddingGeneratorOutputDevice,
    embeddingOutput ~ Tensor gradient' layout' device' dataType' shape',
    HasForward
      encoderTypeEmbedding
      inputType
      embeddingGeneratorOutputDevice
      typeEmbeddingOutput
      typeEmbeddingGeneratorOutputDevice,
    typeEmbeddingOutput ~ Tensor gradient'' layout'' device'' dataType'' shape'',
    HasForward
      encoder
      ( Tensor
          (gradient' <|> gradient'')
          (layout' <+> layout'')
          (device' <+> device'')
          (dataType' <+> dataType'')
          (BroadcastShapesF shape' shape''),
        pos,
        attentionMask
      )
      typeEmbeddingGeneratorOutputDevice
      encoderOutput
      encoderGeneratorOutputDevice,
    Catch (BroadcastShapesF shape' shape''),
    HasForward
      head
      encoderOutput
      encoderGeneratorOutputDevice
      headOutput
      generatorOutputDevice
  ) =>
  HasForward
    (GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head)
    (EncoderOnlyTransformerInput input inputType pos attentionMask)
    generatorDevice
    (EncoderOnlyTransformerOutput headOutput)
    generatorOutputDevice
  where
  forward :: forall (m :: * -> *).
MonadThrow m =>
GEncoderOnlyTransformer
  inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
-> EncoderOnlyTransformerInput input inputType pos attentionMask
-> Generator generatorDevice
-> m (EncoderOnlyTransformerOutput headOutput,
      Generator generatorOutputDevice)
forward GEncoderOnlyTransformer {encoderEmbedding
encoderTypeEmbedding
encoder
head
SDim inputEmbedDim
EncoderOnlyTransformerHasEmbedScaling
eotEmbedScaling :: EncoderOnlyTransformerHasEmbedScaling
eotHead :: head
eotTypeEmbedding :: encoderTypeEmbedding
eotEmbedding :: encoderEmbedding
eotEncoder :: encoder
eotInputEmbedDim :: SDim inputEmbedDim
eotEmbedScaling :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
       encoderEmbedding encoderTypeEmbedding head.
GEncoderOnlyTransformer
  inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
-> EncoderOnlyTransformerHasEmbedScaling
eotHead :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
       encoderEmbedding encoderTypeEmbedding head.
GEncoderOnlyTransformer
  inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
-> head
eotTypeEmbedding :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
       encoderEmbedding encoderTypeEmbedding head.
GEncoderOnlyTransformer
  inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
-> encoderTypeEmbedding
eotEmbedding :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
       encoderEmbedding encoderTypeEmbedding head.
GEncoderOnlyTransformer
  inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
-> encoderEmbedding
eotEncoder :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
       encoderEmbedding encoderTypeEmbedding head.
GEncoderOnlyTransformer
  inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
-> encoder
eotInputEmbedDim :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
       encoderEmbedding encoderTypeEmbedding head.
GEncoderOnlyTransformer
  inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
-> SDim inputEmbedDim
..} EncoderOnlyTransformerInput {input
inputType
pos
attentionMask
eotAttentionMask :: attentionMask
eotPos :: pos
eotInputType :: inputType
eotInput :: input
eotAttentionMask :: forall input inputType pos attentionMask.
EncoderOnlyTransformerInput input inputType pos attentionMask
-> attentionMask
eotPos :: forall input inputType pos attentionMask.
EncoderOnlyTransformerInput input inputType pos attentionMask
-> pos
eotInputType :: forall input inputType pos attentionMask.
EncoderOnlyTransformerInput input inputType pos attentionMask
-> inputType
eotInput :: forall input inputType pos attentionMask.
EncoderOnlyTransformerInput input inputType pos attentionMask
-> input
..} =
    let Double
scaling :: Double = forall a. Floating a => a -> a
sqrt forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (Integral a, Num b) => a -> b
fromIntegral forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. IsChecked a -> a
forgetIsChecked forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall name size. Dim name size -> size
dimSize forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall k (a :: k). SingKind k => Sing a -> Demote k
fromSing forall a b. (a -> b) -> a -> b
$ SDim inputEmbedDim
eotInputEmbedDim
        embeddedInput :: IxStateT
  m
  (Generator generatorDevice)
  (Generator embeddingGeneratorOutputDevice)
  (Tensor gradient' layout' device' dataType' shape')
embeddedInput =
          forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn input
eotInput
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
       output (generatorOutputDevice :: Device (DeviceType Nat))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward encoderEmbedding
eotEmbedding
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift
              forall b c a. (b -> c) -> (a -> b) -> a -> c
. ( \case
                    EncoderOnlyTransformerHasEmbedScaling
EncoderOnlyTransformerWithoutEmbedScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
                    EncoderOnlyTransformerHasEmbedScaling
EncoderOnlyTransformerWithEmbedScaling -> forall a b c. (a -> b -> c) -> b -> a -> c
flip forall other (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(Scalar other, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> other -> m (Tensor gradient layout device dataType shape)
mulScalar Double
scaling
                )
                EncoderOnlyTransformerHasEmbedScaling
eotEmbedScaling
        embeddedInputType :: IxStateT
  m
  (Generator embeddingGeneratorOutputDevice)
  (Generator typeEmbeddingGeneratorOutputDevice)
  (Tensor gradient'' layout'' device'' dataType'' shape'')
embeddedInputType =
          forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn inputType
eotInputType
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
       output (generatorOutputDevice :: Device (DeviceType Nat))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward encoderTypeEmbedding
eotTypeEmbedding
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift
              forall b c a. (b -> c) -> (a -> b) -> a -> c
. ( \case
                    EncoderOnlyTransformerHasEmbedScaling
EncoderOnlyTransformerWithoutEmbedScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
                    EncoderOnlyTransformerHasEmbedScaling
EncoderOnlyTransformerWithEmbedScaling -> forall a b c. (a -> b -> c) -> b -> a -> c
flip forall other (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(Scalar other, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> other -> m (Tensor gradient layout device dataType shape)
mulScalar Double
scaling
                )
                EncoderOnlyTransformerHasEmbedScaling
eotEmbedScaling
     in forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
          (,) forall {k1} {k2} (f :: k1 -> k2 -> * -> *) a b (j :: k1)
       (k3 :: k2).
IxFunctor f =>
(a -> b) -> f j k3 a -> f j k3 b
<<$>> IxStateT
  m
  (Generator generatorDevice)
  (Generator embeddingGeneratorOutputDevice)
  (Tensor gradient' layout' device' dataType' shape')
embeddedInput forall {k1} (f :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a b
       (k2 :: k1).
IxApplicative f =>
f i j (a -> b) -> f j k2 a -> f i k2 b
<<*>> IxStateT
  m
  (Generator embeddingGeneratorOutputDevice)
  (Generator typeEmbeddingGeneratorOutputDevice)
  (Tensor gradient'' layout'' device'' dataType'' shape'')
embeddedInputType
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient' :: Gradient RequiresGradient)
       (layout' :: Layout LayoutType) (device' :: Device (DeviceType Nat))
       (dataType' :: DataType DType)
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(MonadThrow m, shape'' ~ BroadcastShapesF shape shape',
 Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
        (gradient <|> gradient')
        (layout <+> layout')
        (device <+> device')
        (dataType <+> dataType')
        shape'')
add
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= (\Tensor
  (Or (Gradient RequiresGradient) gradient' gradient'')
  (Unify (Layout LayoutType) layout' layout'')
  (Unify (Device (DeviceType Nat)) device' device'')
  (Unify (DataType DType) dataType' dataType'')
  (BroadcastShapesF shape' shape'')
input' -> forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall a b. (a -> b) -> a -> b
$ forall model input (generatorDevice :: Device (DeviceType Nat))
       output (generatorOutputDevice :: Device (DeviceType Nat))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward encoder
eotEncoder (Tensor
  (Or (Gradient RequiresGradient) gradient' gradient'')
  (Unify (Layout LayoutType) layout' layout'')
  (Unify (Device (DeviceType Nat)) device' device'')
  (Unify (DataType DType) dataType' dataType'')
  (BroadcastShapesF shape' shape'')
input', pos
eotPos, attentionMask
eotAttentionMask))
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
       output (generatorOutputDevice :: Device (DeviceType Nat))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward head
eotHead
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall output. output -> EncoderOnlyTransformerOutput output
EncoderOnlyTransformerOutput

-- | Forward pass of the simplified encoder-only transformer wrapper.
--
-- Given only the raw @input@ and @inputType@, this instance
--
--   1. runs 'seotMkPaddingMask' on the input to obtain the padding mask,
--   2. runs 'seotMkPos' on the input to obtain the positions,
--   3. runs 'seotMkAttentionMask' on the padding mask to obtain the attention mask,
--   4. assembles an 'EncoderOnlyTransformerInput' and runs 'seotModel' on it,
--   5. returns the model output together with the padding mask in a
--      'SimplifiedEncoderOnlyTransformerOutput'.
--
-- Per the instance context, the three mask/position helpers leave the
-- generator device unchanged (@generatorDevice -> generatorDevice@); only the
-- wrapped model may move it to @generatorOutputDevice@.
instance
  ( HasForward
      mkPaddingMask
      input
      generatorDevice
      paddingMask
      generatorDevice,
    HasForward
      mkAttentionMask
      paddingMask
      generatorDevice
      attentionMask
      generatorDevice,
    HasForward
      mkPos
      input
      generatorDevice
      pos
      generatorDevice,
    HasForward
      model
      (EncoderOnlyTransformerInput input inputType pos attentionMask)
      generatorDevice
      (EncoderOnlyTransformerOutput output)
      generatorOutputDevice
  ) =>
  HasForward
    (GSimplifiedEncoderOnlyTransformer model mkPos mkPaddingMask mkAttentionMask)
    (SimplifiedEncoderOnlyTransformerInput input inputType)
    generatorDevice
    (SimplifiedEncoderOnlyTransformerOutput output paddingMask)
    generatorOutputDevice
  where
  -- The record wildcards bring 'seotModel', 'seotMkPos', 'seotMkPaddingMask',
  -- 'seotMkAttentionMask', 'seotInput' and 'seotInputType' into scope.
  forward GSimplifiedEncoderOnlyTransformer {..} SimplifiedEncoderOnlyTransformerInput {..} =
    runIxStateT $
      -- Both helpers read only the raw input, so they are sequenced
      -- applicatively: padding mask first, then positions.
      ( let paddingMask = IxStateT . forward seotMkPaddingMask $ seotInput
            pos = IxStateT . forward seotMkPos $ seotInput
         in (,) <<$>> paddingMask <<*>> pos
      )
        >>>= ( \(paddingMask, pos) ->
                 -- The attention mask is computed from the padding mask.
                 let attentionMask = IxStateT . forward seotMkAttentionMask $ paddingMask
                  in ( EncoderOnlyTransformerInput
                         seotInput
                         seotInputType
                         pos
                         <<$>> attentionMask
                     )
                       -- Run the wrapped model on the assembled input.
                       >>>= IxStateT . forward seotModel
                       -- Re-wrap the model output and attach the padding mask.
                       >>>= ( \(EncoderOnlyTransformerOutput output) ->
                                ireturn $ SimplifiedEncoderOnlyTransformerOutput output paddingMask
                            )
             )