{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}

module Torch.GraduallyTyped.NN.Transformer.T5.Common where

import Control.Monad.Catch (MonadThrow)
import Data.Kind (Type)
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (GEncoderDecoderTransformerF, GSimplifiedEncoderDecoderTransformer (..), encoderDecoderTransformerSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (MkRelPos (..), MkTransformerAttentionMask (..), MkTransformerCrossAttentionMask (..), MkTransformerDecoderAttentionMask (..), MkTransformerPaddingMask (..), STransformerHead (), STransformerStyle (SByT5, ST5), ShiftRight (..), TransformerHead (..), TransformerStyle (ByT5, T5), mkTransformerInput)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (SGetDim, Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>))

-- | T5 dType: the promoted 'Float' constructor of 'DType'.
type T5DType = 'Float

-- | T5 dType singleton.
--
-- The value is obtained from the 'SingI' instance through a visible type
-- application, so it always agrees with the 'T5DType' synonym.
t5DType :: SDType T5DType
t5DType = sing @T5DType

-- | T5 data type: 'T5DType' wrapped in the promoted 'DataType' constructor.
type T5DataType = 'DataType T5DType

-- | T5 data type singleton.
--
-- The value is obtained from the 'SingI' instance through a visible type
-- application, so it always agrees with the 'T5DataType' synonym.
t5DataType :: SDataType T5DataType
t5DataType = sing @T5DataType

-- | T5 dropout rate.
-- Mirrors the configuration entry @dropout_rate = 0.1@.
t5DropoutP :: Double
t5DropoutP = 0.1

-- | T5 relative positional encoding bucket dimension.
-- Mirrors the configuration entry @relative_attention_num_buckets = 32@:
-- a dimension named @"*"@ with size 32.
type T5RelPosEncBucketDim = 'Dim ('Name "*") ('Size 32)

-- | T5 relative positional encoding bucket dimension singleton.
--
-- The value is obtained from the 'SingI' instance through a visible type
-- application, so it always agrees with the 'T5RelPosEncBucketDim' synonym.
t5RelPosEncBucketDim :: SDim T5RelPosEncBucketDim
t5RelPosEncBucketDim = sing @T5RelPosEncBucketDim

-- | T5 layer-norm epsilon.
-- Mirrors the configuration entry @layer_norm_epsilon = 1e-06@.
t5Eps :: Double
t5Eps = 1e-6

-- | T5 maximum distance for relative positional encoding.
t5MaxDistance :: Int
t5MaxDistance = 128

-- | T5 padding token id.
-- Mirrors the configuration entry @pad_token_id = 0@.
t5PadTokenId :: Int
t5PadTokenId = 0

-- | T5 begin-of-sentence token id.
--
-- Defined as 't5PadTokenId': the padding token id doubles as the
-- begin-of-sentence id.
t5BOSTokenId :: Int
t5BOSTokenId = t5PadTokenId

-- | T5 end-of-sentence token id.
-- Mirrors the configuration entry @eos_token_id = 1@.
t5EOSTokenId :: Int
t5EOSTokenId = 1

-- | T5 attention mask bias.
--
-- This is the bias value handed to the attention-mask constructors
-- ('MkTransformerAttentionMask', 'MkTransformerCrossAttentionMask' and
-- 'MkTransformerDecoderAttentionMask') when the model specification is built.
t5AttentionMaskBias :: Double
t5AttentionMaskBias = -10000

-- | Wraps an encoder-decoder model type in
-- 'GSimplifiedEncoderDecoderTransformer', filling in the position and mask
-- builder types that both equations of 'T5ModelF' share.
type T5SimplifiedModel :: Type -> Type
type T5SimplifiedModel model =
  GSimplifiedEncoderDecoderTransformer
    model
    (MkRelPos T5RelPosEncBucketDim)
    (MkRelPos T5RelPosEncBucketDim)
    MkTransformerPaddingMask
    (MkTransformerAttentionMask T5DataType)
    (MkTransformerCrossAttentionMask T5DataType)
    (MkTransformerDecoderAttentionMask T5DataType)

-- | Specifies a T5 or ByT5 model.
--
-- The data type is fixed to 'T5DataType' and the relative positional encoding
-- bucket dimension to 'T5RelPosEncBucketDim'; every other parameter is passed
-- through to 'GEncoderDecoderTransformerF'. Only the @'T5@ and @'ByT5@ styles
-- have equations.
type T5ModelF ::
  TransformerStyle ->
  TransformerHead ->
  Nat ->
  Nat ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  HasDropout ->
  Type
type family
  T5ModelF style transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout
  where
  T5ModelF 'T5 transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout =
    T5SimplifiedModel
      (GEncoderDecoderTransformerF 'T5 transformerHead numEncoderLayers numDecoderLayers gradient device T5DataType headDim headEmbedDim embedDim inputEmbedDim ffnDim T5RelPosEncBucketDim vocabDim hasDropout)
  T5ModelF 'ByT5 transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout =
    T5SimplifiedModel
      (GEncoderDecoderTransformerF 'ByT5 transformerHead numEncoderLayers numDecoderLayers gradient device T5DataType headDim headEmbedDim embedDim inputEmbedDim ffnDim T5RelPosEncBucketDim vocabDim hasDropout)

-- | Specifies the parameters of a T5 or ByT5 model.
--
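-- - @style@: the transformer style; 'ST5' and 'SByT5' are handled, any other style falls through to @undefined@.
-- - @transformerHead@: the head of the T5 or ByT5 model.
-- - @numEncoderLayers@: the number of layers in the encoder of the T5 or ByT5 model.
-- - @numDecoderLayers@: the number of layers in the decoder of the T5 or ByT5 model.
-- - @gradient@: whether to compute the gradient of the T5 or ByT5 model.
-- - @device@: the computational device on which the T5 or ByT5 model parameters are to be allocated.
-- - @hasDropout@: whether the T5 or ByT5 model is specified with dropout.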
t5ModelSpec ::
  forall style transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout.
  ( SingI headDim,
    SingI headEmbedDim,
    SingI embedDim,
    SingI inputEmbedDim,
    SingI ffnDim,
    SingI vocabDim
  ) =>
  STransformerStyle style ->
  STransformerHead transformerHead ->
  SNat numEncoderLayers ->
  SNat numDecoderLayers ->
  SGradient gradient ->
  SDevice device ->
  SHasDropout hasDropout ->
  ModelSpec (T5ModelF style transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout)
t5ModelSpec :: forall (style :: TransformerStyle)
       (transformerHead :: TransformerHead) (numEncoderLayers :: Natural)
       (numDecoderLayers :: Natural)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Natural))
       (headDim :: Dim (Name Symbol) (Size Natural))
       (headEmbedDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (inputEmbedDim :: Dim (Name Symbol) (Size Natural))
       (ffnDim :: Dim (Name Symbol) (Size Natural))
       (vocabDim :: Dim (Name Symbol) (Size Natural))
       (hasDropout :: HasDropout).
(SingI headDim, SingI headEmbedDim, SingI embedDim,
 SingI inputEmbedDim, SingI ffnDim, SingI vocabDim) =>
STransformerStyle style
-> STransformerHead transformerHead
-> SNat numEncoderLayers
-> SNat numDecoderLayers
-> SGradient gradient
-> SDevice device
-> SHasDropout hasDropout
-> ModelSpec
     (T5ModelF
        style
        transformerHead
        numEncoderLayers
        numDecoderLayers
        gradient
        device
        headDim
        headEmbedDim
        embedDim
        inputEmbedDim
        ffnDim
        vocabDim
        hasDropout)
t5ModelSpec STransformerStyle style
style STransformerHead transformerHead
transformerHead SNat numEncoderLayers
numEncoderLayers SNat numDecoderLayers
numDecoderLayers SGradient gradient
gradient SDevice device
device SHasDropout hasDropout
hasDropout =
  case STransformerStyle style
style of
    STransformerStyle style
ST5 ->
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
model
-> ShiftRight Int
-> ShiftRight Int
-> mkPos
-> mkDecoderPos
-> mkPaddingMask
-> mkAttentionMask
-> mkCrossAttentionMask
-> mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
GSimplifiedEncoderDecoderTransformer
        (forall {style :: TransformerStyle}.
STransformerStyle style
-> GEncoderDecoderTransformer
     inputEmbedDim
     (NamedModel
        (GTransformer
           (ModelSpec
              (TEPosEncF
                 style
                 gradient
                 device
                 T5DataType
                 inputEmbedDim
                 T5RelPosEncBucketDim))
           (ModelSpec
              (TERelPosEncF
                 style gradient device T5DataType headDim T5RelPosEncBucketDim))
           (ModelSpec
              (TEInitialLayerNormF
                 style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TEInitialDropoutF style hasDropout))
           (NamedModel
              (GTransformerStack
                 (VectorSpec
                    numEncoderLayers
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (SAInitialLayerNormF
                                style gradient device T5DataType inputEmbedDim)
                             (NamedModel
                                (GMultiHeadAttention
                                   headDim
                                   headEmbedDim
                                   embedDim
                                   (QInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (KInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (VInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (OutProjF
                                      style gradient device T5DataType embedDim inputEmbedDim)
                                   (DropoutF style hasDropout)))
                             (SADropoutF style hasDropout)
                             (SAFinalLayerNormF
                                style gradient device T5DataType inputEmbedDim)))
                       ()
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (FFNInputLayerNormF style gradient device T5DataType inputEmbedDim)
                             (FFNInputTransformationF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNActivationF style)
                             (FFNActivationDropoutF style hasDropout)
                             (FFNOutputProjectionF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNOutputDropoutF style hasDropout)
                             (FFNOutputLayerNormF
                                style gradient device T5DataType inputEmbedDim)))))))
           (ModelSpec
              (TEFinalLayerNormF style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TEFinalDropoutF style hasDropout))))
     (NamedModel
        (GTransformer
           (ModelSpec
              (TDPosEncF
                 style
                 gradient
                 device
                 T5DataType
                 inputEmbedDim
                 T5RelPosEncBucketDim))
           (ModelSpec
              (TDRelPosEncF
                 style gradient device T5DataType headDim T5RelPosEncBucketDim))
           (ModelSpec
              (TDInitialLayerNormF
                 style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TDInitialDropoutF style hasDropout))
           (NamedModel
              (GTransformerStack
                 (VectorSpec
                    numDecoderLayers
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (SAInitialLayerNormF
                                style gradient device T5DataType inputEmbedDim)
                             (NamedModel
                                (GMultiHeadAttention
                                   headDim
                                   headEmbedDim
                                   embedDim
                                   (QInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (KInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (VInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (OutProjF
                                      style gradient device T5DataType embedDim inputEmbedDim)
                                   (DropoutF style hasDropout)))
                             (SADropoutF style hasDropout)
                             (SAFinalLayerNormF
                                style gradient device T5DataType inputEmbedDim)))
                       (NamedModel
                          (GCrossAttention
                             (CAInitialLayerNormF
                                style gradient device T5DataType inputEmbedDim)
                             (NamedModel
                                (GMultiHeadAttention
                                   headDim
                                   headEmbedDim
                                   embedDim
                                   (QInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (KInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (VInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (OutProjF
                                      style gradient device T5DataType embedDim inputEmbedDim)
                                   (DropoutF style hasDropout)))
                             (CADropoutF style hasDropout)
                             (CAFinalLayerNormF
                                style gradient device T5DataType inputEmbedDim)))
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (FFNInputLayerNormF style gradient device T5DataType inputEmbedDim)
                             (FFNInputTransformationF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNActivationF style)
                             (FFNActivationDropoutF style hasDropout)
                             (FFNOutputProjectionF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNOutputDropoutF style hasDropout)
                             (FFNOutputLayerNormF
                                style gradient device T5DataType inputEmbedDim)))))))
           (ModelSpec
              (TDFinalLayerNormF style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TDFinalDropoutF style hasDropout))))
     (NamedModel
        (EmbeddingSpec
           gradient
           ('Layout 'Dense)
           device
           T5DataType
           vocabDim
           inputEmbedDim
           'Nothing))
     (ModelSpec
        (EDTHeadF
           style
           transformerHead
           gradient
           device
           T5DataType
           inputEmbedDim
           vocabDim))
modelSpec' STransformerStyle 'T5
ST5)
        (forall fillValue. fillValue -> ShiftRight fillValue
ShiftRight Int
t5BOSTokenId)
        (forall fillValue. fillValue -> ShiftRight fillValue
ShiftRight Int
0)
        (forall (relPosEncBucketDim :: Dim (Name Symbol) (Size Natural)).
SDim relPosEncBucketDim -> Int -> MkRelPos relPosEncBucketDim
MkRelPos SDim T5RelPosEncBucketDim
t5RelPosEncBucketDim Int
t5MaxDistance)
        (forall (relPosEncBucketDim :: Dim (Name Symbol) (Size Natural)).
SDim relPosEncBucketDim -> Int -> MkRelPos relPosEncBucketDim
MkDecoderRelPos SDim T5RelPosEncBucketDim
t5RelPosEncBucketDim Int
t5MaxDistance)
        (Int -> MkTransformerPaddingMask
MkTransformerPaddingMask Int
t5PadTokenId)
        (forall (dataType :: DataType DType).
SDataType dataType -> Double -> MkTransformerAttentionMask dataType
MkTransformerAttentionMask SDataType T5DataType
t5DataType Double
t5AttentionMaskBias)
        (forall (dataType :: DataType DType).
SDataType dataType
-> Double -> MkTransformerCrossAttentionMask dataType
MkTransformerCrossAttentionMask SDataType T5DataType
t5DataType Double
t5AttentionMaskBias)
        (forall (dataType :: DataType DType).
SDataType dataType
-> Double -> MkTransformerDecoderAttentionMask dataType
MkTransformerDecoderAttentionMask SDataType T5DataType
t5DataType Double
t5AttentionMaskBias)
    STransformerStyle style
SByT5 ->
      forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
       mkCrossAttentionMask mkDecoderAttentionMask.
model
-> ShiftRight Int
-> ShiftRight Int
-> mkPos
-> mkDecoderPos
-> mkPaddingMask
-> mkAttentionMask
-> mkCrossAttentionMask
-> mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
     model
     mkPos
     mkDecoderPos
     mkPaddingMask
     mkAttentionMask
     mkCrossAttentionMask
     mkDecoderAttentionMask
GSimplifiedEncoderDecoderTransformer
        (forall {style :: TransformerStyle}.
STransformerStyle style
-> GEncoderDecoderTransformer
     inputEmbedDim
     (NamedModel
        (GTransformer
           (ModelSpec
              (TEPosEncF
                 style
                 gradient
                 device
                 T5DataType
                 inputEmbedDim
                 T5RelPosEncBucketDim))
           (ModelSpec
              (TERelPosEncF
                 style gradient device T5DataType headDim T5RelPosEncBucketDim))
           (ModelSpec
              (TEInitialLayerNormF
                 style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TEInitialDropoutF style hasDropout))
           (NamedModel
              (GTransformerStack
                 (VectorSpec
                    numEncoderLayers
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (SAInitialLayerNormF
                                style gradient device T5DataType inputEmbedDim)
                             (NamedModel
                                (GMultiHeadAttention
                                   headDim
                                   headEmbedDim
                                   embedDim
                                   (QInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (KInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (VInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (OutProjF
                                      style gradient device T5DataType embedDim inputEmbedDim)
                                   (DropoutF style hasDropout)))
                             (SADropoutF style hasDropout)
                             (SAFinalLayerNormF
                                style gradient device T5DataType inputEmbedDim)))
                       ()
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (FFNInputLayerNormF style gradient device T5DataType inputEmbedDim)
                             (FFNInputTransformationF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNActivationF style)
                             (FFNActivationDropoutF style hasDropout)
                             (FFNOutputProjectionF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNOutputDropoutF style hasDropout)
                             (FFNOutputLayerNormF
                                style gradient device T5DataType inputEmbedDim)))))))
           (ModelSpec
              (TEFinalLayerNormF style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TEFinalDropoutF style hasDropout))))
     (NamedModel
        (GTransformer
           (ModelSpec
              (TDPosEncF
                 style
                 gradient
                 device
                 T5DataType
                 inputEmbedDim
                 T5RelPosEncBucketDim))
           (ModelSpec
              (TDRelPosEncF
                 style gradient device T5DataType headDim T5RelPosEncBucketDim))
           (ModelSpec
              (TDInitialLayerNormF
                 style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TDInitialDropoutF style hasDropout))
           (NamedModel
              (GTransformerStack
                 (VectorSpec
                    numDecoderLayers
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (SAInitialLayerNormF
                                style gradient device T5DataType inputEmbedDim)
                             (NamedModel
                                (GMultiHeadAttention
                                   headDim
                                   headEmbedDim
                                   embedDim
                                   (QInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (KInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (VInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (OutProjF
                                      style gradient device T5DataType embedDim inputEmbedDim)
                                   (DropoutF style hasDropout)))
                             (SADropoutF style hasDropout)
                             (SAFinalLayerNormF
                                style gradient device T5DataType inputEmbedDim)))
                       (NamedModel
                          (GCrossAttention
                             (CAInitialLayerNormF
                                style gradient device T5DataType inputEmbedDim)
                             (NamedModel
                                (GMultiHeadAttention
                                   headDim
                                   headEmbedDim
                                   embedDim
                                   (QInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (KInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (VInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (OutProjF
                                      style gradient device T5DataType embedDim inputEmbedDim)
                                   (DropoutF style hasDropout)))
                             (CADropoutF style hasDropout)
                             (CAFinalLayerNormF
                                style gradient device T5DataType inputEmbedDim)))
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (FFNInputLayerNormF style gradient device T5DataType inputEmbedDim)
                             (FFNInputTransformationF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNActivationF style)
                             (FFNActivationDropoutF style hasDropout)
                             (FFNOutputProjectionF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNOutputDropoutF style hasDropout)
                             (FFNOutputLayerNormF
                                style gradient device T5DataType inputEmbedDim)))))))
           (ModelSpec
              (TDFinalLayerNormF style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TDFinalDropoutF style hasDropout))))
     (NamedModel
        (EmbeddingSpec
           gradient
           ('Layout 'Dense)
           device
           T5DataType
           vocabDim
           inputEmbedDim
           'Nothing))
     (ModelSpec
        (EDTHeadF
           style
           transformerHead
           gradient
           device
           T5DataType
           inputEmbedDim
           vocabDim))
modelSpec' STransformerStyle 'ByT5
SByT5)
        (forall fillValue. fillValue -> ShiftRight fillValue
ShiftRight Int
t5BOSTokenId)
        (forall fillValue. fillValue -> ShiftRight fillValue
ShiftRight Int
0)
        (forall (relPosEncBucketDim :: Dim (Name Symbol) (Size Natural)).
SDim relPosEncBucketDim -> Int -> MkRelPos relPosEncBucketDim
MkRelPos SDim T5RelPosEncBucketDim
t5RelPosEncBucketDim Int
t5MaxDistance)
        (forall (relPosEncBucketDim :: Dim (Name Symbol) (Size Natural)).
SDim relPosEncBucketDim -> Int -> MkRelPos relPosEncBucketDim
MkDecoderRelPos SDim T5RelPosEncBucketDim
t5RelPosEncBucketDim Int
t5MaxDistance)
        (Int -> MkTransformerPaddingMask
MkTransformerPaddingMask Int
t5PadTokenId)
        (forall (dataType :: DataType DType).
SDataType dataType -> Double -> MkTransformerAttentionMask dataType
MkTransformerAttentionMask SDataType T5DataType
t5DataType Double
t5AttentionMaskBias)
        (forall (dataType :: DataType DType).
SDataType dataType
-> Double -> MkTransformerCrossAttentionMask dataType
MkTransformerCrossAttentionMask SDataType T5DataType
t5DataType Double
t5AttentionMaskBias)
        (forall (dataType :: DataType DType).
SDataType dataType
-> Double -> MkTransformerDecoderAttentionMask dataType
MkTransformerDecoderAttentionMask SDataType T5DataType
t5DataType Double
t5AttentionMaskBias)
    STransformerStyle style
_ -> forall a. HasCallStack => a
undefined
  where
    modelSpec' :: _
    modelSpec' :: STransformerStyle style
-> GEncoderDecoderTransformer
     inputEmbedDim
     (NamedModel
        (GTransformer
           (ModelSpec
              (TEPosEncF
                 style
                 gradient
                 device
                 T5DataType
                 inputEmbedDim
                 T5RelPosEncBucketDim))
           (ModelSpec
              (TERelPosEncF
                 style gradient device T5DataType headDim T5RelPosEncBucketDim))
           (ModelSpec
              (TEInitialLayerNormF
                 style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TEInitialDropoutF style hasDropout))
           (NamedModel
              (GTransformerStack
                 (VectorSpec
                    numEncoderLayers
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (SAInitialLayerNormF
                                style gradient device T5DataType inputEmbedDim)
                             (NamedModel
                                (GMultiHeadAttention
                                   headDim
                                   headEmbedDim
                                   embedDim
                                   (QInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (KInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (VInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (OutProjF
                                      style gradient device T5DataType embedDim inputEmbedDim)
                                   (DropoutF style hasDropout)))
                             (SADropoutF style hasDropout)
                             (SAFinalLayerNormF
                                style gradient device T5DataType inputEmbedDim)))
                       ()
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (FFNInputLayerNormF style gradient device T5DataType inputEmbedDim)
                             (FFNInputTransformationF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNActivationF style)
                             (FFNActivationDropoutF style hasDropout)
                             (FFNOutputProjectionF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNOutputDropoutF style hasDropout)
                             (FFNOutputLayerNormF
                                style gradient device T5DataType inputEmbedDim)))))))
           (ModelSpec
              (TEFinalLayerNormF style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TEFinalDropoutF style hasDropout))))
     (NamedModel
        (GTransformer
           (ModelSpec
              (TDPosEncF
                 style
                 gradient
                 device
                 T5DataType
                 inputEmbedDim
                 T5RelPosEncBucketDim))
           (ModelSpec
              (TDRelPosEncF
                 style gradient device T5DataType headDim T5RelPosEncBucketDim))
           (ModelSpec
              (TDInitialLayerNormF
                 style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TDInitialDropoutF style hasDropout))
           (NamedModel
              (GTransformerStack
                 (VectorSpec
                    numDecoderLayers
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (SAInitialLayerNormF
                                style gradient device T5DataType inputEmbedDim)
                             (NamedModel
                                (GMultiHeadAttention
                                   headDim
                                   headEmbedDim
                                   embedDim
                                   (QInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (KInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (VInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (OutProjF
                                      style gradient device T5DataType embedDim inputEmbedDim)
                                   (DropoutF style hasDropout)))
                             (SADropoutF style hasDropout)
                             (SAFinalLayerNormF
                                style gradient device T5DataType inputEmbedDim)))
                       (NamedModel
                          (GCrossAttention
                             (CAInitialLayerNormF
                                style gradient device T5DataType inputEmbedDim)
                             (NamedModel
                                (GMultiHeadAttention
                                   headDim
                                   headEmbedDim
                                   embedDim
                                   (QInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (KInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (VInProjF
                                      style gradient device T5DataType inputEmbedDim embedDim)
                                   (OutProjF
                                      style gradient device T5DataType embedDim inputEmbedDim)
                                   (DropoutF style hasDropout)))
                             (CADropoutF style hasDropout)
                             (CAFinalLayerNormF
                                style gradient device T5DataType inputEmbedDim)))
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (FFNInputLayerNormF style gradient device T5DataType inputEmbedDim)
                             (FFNInputTransformationF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNActivationF style)
                             (FFNActivationDropoutF style hasDropout)
                             (FFNOutputProjectionF
                                style gradient device T5DataType inputEmbedDim ffnDim)
                             (FFNOutputDropoutF style hasDropout)
                             (FFNOutputLayerNormF
                                style gradient device T5DataType inputEmbedDim)))))))
           (ModelSpec
              (TDFinalLayerNormF style gradient device T5DataType inputEmbedDim))
           (ModelSpec (TDFinalDropoutF style hasDropout))))
     (NamedModel
        (EmbeddingSpec
           gradient
           ('Layout 'Dense)
           device
           T5DataType
           vocabDim
           inputEmbedDim
           'Nothing))
     (ModelSpec
        (EDTHeadF
           style
           transformerHead
           gradient
           device
           T5DataType
           inputEmbedDim
           vocabDim))
modelSpec' STransformerStyle style
style' =
      forall (style :: TransformerStyle)
       (transformerHead :: TransformerHead) (numEncoderLayers :: Natural)
       (numDecoderLayers :: Natural)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Natural))
       (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Natural))
       (headEmbedDim :: Dim (Name Symbol) (Size Natural))
       (embedDim :: Dim (Name Symbol) (Size Natural))
       (inputEmbedDim :: Dim (Name Symbol) (Size Natural))
       (ffnDim :: Dim (Name Symbol) (Size Natural))
       (posEncDim :: Dim (Name Symbol) (Size Natural))
       (vocabDim :: Dim (Name Symbol) (Size Natural))
       (hasDropout :: HasDropout).
STransformerStyle style
-> STransformerHead transformerHead
-> SNat numEncoderLayers
-> SNat numDecoderLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SDim vocabDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
     (GEncoderDecoderTransformerF
        style
        transformerHead
        numEncoderLayers
        numDecoderLayers
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        inputEmbedDim
        ffnDim
        posEncDim
        vocabDim
        hasDropout)
encoderDecoderTransformerSpec
        STransformerStyle style
style'
        STransformerHead transformerHead
transformerHead
        SNat numEncoderLayers
numEncoderLayers
        SNat numDecoderLayers
numDecoderLayers
        SGradient gradient
gradient
        SDevice device
device
        SDataType T5DataType
t5DataType
        (forall {k} (a :: k). SingI a => Sing a
sing @headDim)
        (forall {k} (a :: k). SingI a => Sing a
sing @headEmbedDim)
        (forall {k} (a :: k). SingI a => Sing a
sing @embedDim)
        (forall {k} (a :: k). SingI a => Sing a
sing @inputEmbedDim)
        (forall {k} (a :: k). SingI a => Sing a
sing @ffnDim)
        SDim T5RelPosEncBucketDim
t5RelPosEncBucketDim
        (forall {k} (a :: k). SingI a => Sing a
sing @vocabDim)
        SHasDropout hasDropout
hasDropout
        Double
t5DropoutP
        Double
t5Eps

-- | Build a T5 input tensor of token ids from nested lists of 'Int's.
--
-- This is 'mkTransformerInput' partially applied to 't5PadTokenId'. The
-- remaining arguments are, in order: the batch dimension singleton, the
-- sequence dimension singleton, the target device singleton, and the token
-- ids as a list of lists.
--
-- NOTE(review): presumably the outer list indexes the batch and each inner
-- list holds one sequence, with 't5PadTokenId' used to fill shorter
-- sequences -- confirm against 'mkTransformerInput'.
--
-- The result is a dense 'Int64' tensor without gradient of shape
-- @[batchDim, seqDim]@ on @device@, returned in any 'MonadThrow' monad.
mkT5Input ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  -- | batch dimension
  SDim batchDim ->
  -- | sequence dimension
  SDim seqDim ->
  -- | device the tensor is created on
  SDevice device ->
  -- | token ids
  [[Int]] ->
  -- | input tensor of shape @[batchDim, seqDim]@
  m output
mkT5Input = mkTransformerInput t5PadTokenId