{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin TypeLevel.Rewrite
                -fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyRightAssociativeL #-}

module Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention where

import Control.Monad.Catch (MonadThrow)
import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import Data.Singletons (SingKind (..))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Dropout (Dropout (..))
import Torch.GraduallyTyped.NN.Functional.NonLinearActivation (SoftmaxF, softmax)
import Torch.GraduallyTyped.NN.Linear (GLinear (..), GLinearF, linearSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), HasDropout (..), SHasBias (..), SHasDropout (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF, sGetDimFromShape, sUnifyDim, type (!))
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim (..), SSelectDim (..), SShape (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (ReshapeF, TransposeF, sReshape, sTranspose)
import Torch.GraduallyTyped.Tensor.MathOperations.BlasLapack (MatmulF, matmul)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add, mulScalar)
import Torch.GraduallyTyped.Tensor.Type (SGetShape (..), Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

-- | Data type for representing whether or not (and, if so, where) scaling is applied in the multi-headed attention layer.
--
-- The derived 'Ord' instance follows the declaration order of the constructors.
data MultiHeadAttentionHasScaling
  = -- | Scaling is not done.
    MultiHeadAttentionWithoutScaling
  | -- | Scaling is applied to the query after the in-projection.
    MultiHeadAttentionWithQueryScaling
  | -- | Scaling is applied to the attention weights.
    MultiHeadAttentionWithWeightScaling
  deriving stock (Eq, Ord, Show, Generic)

-- | The specification of a scaling flag is the flag itself.
type instance
  ModelSpec MultiHeadAttentionHasScaling =
    MultiHeadAttentionHasScaling

-- | Initializing a scaling flag is trivial: the specification is returned as
-- the model, and the generator is handed back exactly as it was received.
instance HasInitialize MultiHeadAttentionHasScaling generatorDevice MultiHeadAttentionHasScaling generatorDevice where
  initialize hasScaling g = pure (hasScaling, g)

-- | A scaling flag is not stored in the state dictionary.
--
-- 'fromStateDict' returns the given specification and ignores the key;
-- 'toStateDict' ignores both of its arguments and does nothing.
instance HasStateDict MultiHeadAttentionHasScaling where
  fromStateDict hasScaling _ = pure hasScaling
  toStateDict _ _ = pure ()

-- | Generic multi-headed attention layer.
--
-- - @headDim@ is the dimension of the attention heads.
-- - @headEmbedDim@ is the dimension of the attention head embedding.
-- - @embedDim@ is the dimension of the embedding.
-- - @qInProj@ is the type of the query projection.
-- - @kInProj@ is the type of the key projection.
-- - @vInProj@ is the type of the value projection.
-- - @outProj@ is the type of the output projection.
-- - @dropout@ is the type of the dropout layer.
--
-- The three dimensions are carried as singleton values ('SDim'), while the
-- projections and the dropout layer are stored as values of the given types.
data
  GMultiHeadAttention
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (qInProj :: Type)
    (kInProj :: Type)
    (vInProj :: Type)
    (outProj :: Type)
    (dropout :: Type)
  where
  GMultiHeadAttention ::
    forall headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout.
    { -- | head dim
      mhaHeadDim :: SDim headDim,
      -- | head embed dim
      mhaHeadEmbedDim :: SDim headEmbedDim,
      -- | embed dim
      mhaEmbedDim :: SDim embedDim,
      -- | in-projection for query
      mhaQInProj :: qInProj,
      -- | in-projection for key
      mhaKInProj :: kInProj,
      -- | in-projection for value
      mhaVInProj :: vInProj,
      -- | out-projection
      mhaOutProj :: outProj,
      -- | dropout
      mhaDropout :: dropout,
      -- | scaling
      mhaScaling :: MultiHeadAttentionHasScaling
    } ->
    GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout
  deriving stock (Show, Generic)

-- | The specification of a 'GMultiHeadAttention' is itself a 'GMultiHeadAttention':
-- the three dimension parameters are kept as they are, and each of the
-- projection and dropout components is replaced by its own 'ModelSpec'.
type instance
  ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) =
    GMultiHeadAttention
      headDim
      headEmbedDim
      embedDim
      (ModelSpec qInProj)
      (ModelSpec kInProj)
      (ModelSpec vInProj)
      (ModelSpec outProj)
      (ModelSpec dropout)

-- | Computes the concrete 'GMultiHeadAttention' type for a given transformer style.
--
-- The query, key, and value in-projection types are obtained from 'QInProjF',
-- 'KInProjF', and 'VInProjF', each applied to the respective embedding dimension
-- and @embedDim@. The out-projection type is obtained from 'OutProjF' applied to
-- @embedDim@ and @queryEmbedDim@, and the dropout type from 'DropoutF'.
type GMultiHeadAttentionF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  HasDropout ->
  Type
type family GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout where
  GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout =
    GMultiHeadAttention
      headDim
      headEmbedDim
      embedDim
      (QInProjF style gradient device dataType queryEmbedDim embedDim)
      (KInProjF style gradient device dataType keyEmbedDim embedDim)
      (VInProjF style gradient device dataType valueEmbedDim embedDim)
      (OutProjF style gradient device dataType embedDim queryEmbedDim)
      (DropoutF style hasDropout)

-- | Type of the query in-projection for a given transformer style.
--
-- For @'T5@ and @'ByT5@ this is a named linear layer without bias;
-- every other style gets a named linear layer with bias.
type QInProjF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Type
type family QInProjF style gradient device dataType queryEmbedDim embedDim where
  QInProjF 'T5 gradient device dataType queryEmbedDim embedDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim embedDim)
  -- ByT5 shares the T5 definition.
  QInProjF 'ByT5 gradient device dataType queryEmbedDim embedDim =
    QInProjF 'T5 gradient device dataType queryEmbedDim embedDim
  QInProjF _ gradient device dataType queryEmbedDim embedDim =
    NamedModel (GLinearF 'WithBias gradient device dataType queryEmbedDim embedDim)

-- | Type of the key in-projection for a given transformer style.
--
-- For @'T5@ and @'ByT5@ this is a named linear layer without bias;
-- every other style gets a named linear layer with bias.
type KInProjF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Type
type family KInProjF style gradient device dataType keyEmbedDim embedDim where
  KInProjF 'T5 gradient device dataType keyEmbedDim embedDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType keyEmbedDim embedDim)
  -- ByT5 shares the T5 definition.
  KInProjF 'ByT5 gradient device dataType keyEmbedDim embedDim =
    KInProjF 'T5 gradient device dataType keyEmbedDim embedDim
  KInProjF _ gradient device dataType keyEmbedDim embedDim =
    NamedModel (GLinearF 'WithBias gradient device dataType keyEmbedDim embedDim)

-- | Type of the value in-projection for a given transformer style.
--
-- For @'T5@ and @'ByT5@ this is a named linear layer without bias;
-- every other style gets a named linear layer with bias.
type VInProjF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Type
type family VInProjF style gradient device dataType valueEmbedDim embedDim where
  VInProjF 'T5 gradient device dataType valueEmbedDim embedDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType valueEmbedDim embedDim)
  -- ByT5 shares the T5 definition.
  VInProjF 'ByT5 gradient device dataType valueEmbedDim embedDim =
    VInProjF 'T5 gradient device dataType valueEmbedDim embedDim
  VInProjF _ gradient device dataType valueEmbedDim embedDim =
    NamedModel (GLinearF 'WithBias gradient device dataType valueEmbedDim embedDim)

-- | Type of the out-projection layer for a given transformer style.
--
-- For @'T5@ and @'ByT5@ this is a named linear layer without bias;
-- every other style gets a named linear layer with bias. Note that the
-- dimension arguments are given as @embedDim@ first and @queryEmbedDim@ second,
-- the reverse of the order used by 'QInProjF'.
type OutProjF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Type
type family OutProjF style gradient device dataType embedDim queryEmbedDim where
  OutProjF 'T5 gradient device dataType embedDim queryEmbedDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType embedDim queryEmbedDim)
  -- ByT5 shares the T5 definition.
  OutProjF 'ByT5 gradient device dataType embedDim queryEmbedDim =
    OutProjF 'T5 gradient device dataType embedDim queryEmbedDim
  OutProjF _ gradient device dataType embedDim queryEmbedDim =
    NamedModel (GLinearF 'WithBias gradient device dataType embedDim queryEmbedDim)

-- | Type of the dropout layer.
--
-- The transformer style is ignored: @'WithDropout@ selects 'Dropout',
-- and @'WithoutDropout@ selects the unit type.
type DropoutF ::
  TransformerStyle ->
  HasDropout ->
  Type
type family DropoutF style hasDropout where
  DropoutF _ 'WithDropout = Dropout
  DropoutF _ 'WithoutDropout = ()

-- | Specifies the parameters of a multi-headed attention layer.
--
-- - @style@: the style of the attention layer, e.g. 'ST5', 'SByT5', etc.
-- - @gradient@: whether to compute the gradient of the attention layer.
-- - @device@: the computational device on which to allocate the attention layer.
-- - @dataType@: the data type of the attention layer.
-- - @headDim@: the dimension of the attention heads.
-- - @headEmbedDim@: the dimension of the attention head embeddings.
-- - @embedDim@: the dimension of the input embeddings.
-- - @queryEmbedDim@: the dimension of the query embeddings.
-- - @keyEmbedDim@: the dimension of the key embeddings.
-- - @valueEmbedDim@: the dimension of the value embeddings.
-- - @dropoutP@: the dropout rate.
multiHeadAttentionSpec ::
  forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout.
  STransformerStyle style ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SDim keyEmbedDim ->
  SDim valueEmbedDim ->
  SHasDropout hasDropout ->
  Double ->
  ModelSpec (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout)
multiHeadAttentionSpec :: forall (style :: TransformerStyle)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat))
       (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
       (keyEmbedDim :: Dim (Name Symbol) (Size Nat))
       (valueEmbedDim :: Dim (Name Symbol) (Size Nat))
       (hasDropout :: HasDropout).
STransformerStyle style
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim queryEmbedDim
-> SDim keyEmbedDim
-> SDim valueEmbedDim
-> SHasDropout hasDropout
-> Double
-> ModelSpec
     (GMultiHeadAttentionF
        style
        gradient
        device
        dataType
        headDim
        headEmbedDim
        embedDim
        queryEmbedDim
        keyEmbedDim
        valueEmbedDim
        hasDropout)
multiHeadAttentionSpec STransformerStyle style
style SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim SDim keyEmbedDim
keyEmbedDim SDim valueEmbedDim
valueEmbedDim SHasDropout hasDropout
hasDropout Double
dropoutP =
  let qInProjSpec :: STransformerStyle style
-> ModelSpec
     (QInProjF style gradient device dataType queryEmbedDim embedDim)
qInProjSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
projSpecWithoutBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
      qInProjSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
projSpecWithoutBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
      qInProjSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
      qInProjSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
      qInProjSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
      qInProjSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.query." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
      qInProjSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.query." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
      qInProjSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      kInProjSpec :: STransformerStyle style
-> ModelSpec
     (KInProjF style gradient device dataType keyEmbedDim embedDim)
kInProjSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
projSpecWithoutBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
      kInProjSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
projSpecWithoutBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
      kInProjSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
      kInProjSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
      kInProjSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
      kInProjSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.key." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
      kInProjSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.key." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
      kInProjSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      vInProjSpec :: STransformerStyle style
-> ModelSpec
     (VInProjF style gradient device dataType valueEmbedDim embedDim)
vInProjSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
projSpecWithoutBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
      vInProjSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
projSpecWithoutBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
      vInProjSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
      vInProjSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
      vInProjSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
      vInProjSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.value." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
      vInProjSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.value." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
      vInProjSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      outProjSpec :: STransformerStyle style
-> ModelSpec
     (OutProjF style gradient device dataType embedDim queryEmbedDim)
outProjSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"o." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
projSpecWithoutBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
      outProjSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"o." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
projSpecWithoutBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
      outProjSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"out_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
      outProjSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"out_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
      outProjSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"out_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
      outProjSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"output.dense." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
      outProjSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"output.dense." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
      outProjSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
      dropoutSpec :: STransformerStyle style
-> SHasDropout hasDropout -> ModelSpec (DropoutF style hasDropout)
dropoutSpec STransformerStyle style
_ SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
      dropoutSpec STransformerStyle style
_ SHasDropout hasDropout
SWithoutDropout = ()
      scaling :: STransformerStyle style -> MultiHeadAttentionHasScaling
      scaling :: STransformerStyle style -> MultiHeadAttentionHasScaling
scaling STransformerStyle style
ST5 = MultiHeadAttentionHasScaling
MultiHeadAttentionWithoutScaling
      scaling STransformerStyle style
SByT5 = MultiHeadAttentionHasScaling
MultiHeadAttentionWithoutScaling
      scaling STransformerStyle style
SBART = MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling
      scaling STransformerStyle style
SMBART = MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling
      scaling STransformerStyle style
SPegasus = MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling
      scaling STransformerStyle style
SBERT = MultiHeadAttentionHasScaling
MultiHeadAttentionWithWeightScaling
      scaling STransformerStyle style
SRoBERTa = MultiHeadAttentionHasScaling
MultiHeadAttentionWithWeightScaling
      scaling STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
   in forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> qInProj
-> kInProj
-> vInProj
-> outProj
-> dropout
-> MultiHeadAttentionHasScaling
-> GMultiHeadAttention
     headDim
     headEmbedDim
     embedDim
     qInProj
     kInProj
     vInProj
     outProj
     dropout
GMultiHeadAttention
        SDim headDim
headDim
        SDim headEmbedDim
headEmbedDim
        SDim embedDim
embedDim
        (STransformerStyle style
-> ModelSpec
     (QInProjF style gradient device dataType queryEmbedDim embedDim)
qInProjSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (KInProjF style gradient device dataType keyEmbedDim embedDim)
kInProjSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (VInProjF style gradient device dataType valueEmbedDim embedDim)
vInProjSpec STransformerStyle style
style)
        (STransformerStyle style
-> ModelSpec
     (OutProjF style gradient device dataType embedDim queryEmbedDim)
outProjSpec STransformerStyle style
style)
        (STransformerStyle style
-> SHasDropout hasDropout -> ModelSpec (DropoutF style hasDropout)
dropoutSpec STransformerStyle style
style SHasDropout hasDropout
hasDropout)
        (STransformerStyle style -> MultiHeadAttentionHasScaling
scaling STransformerStyle style
style)
  where
    projSpecWithoutBias ::
      forall inputDim outputDim.
      SDim inputDim ->
      SDim outputDim ->
      ModelSpec
        ( GLinear
            (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])))
            (NamedModel ())
        )
    projSpecWithoutBias :: forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel ()))
projSpecWithoutBias = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinearF hasBias gradient device dataType inputDim outputDim)
linearSpec SHasBias 'WithoutBias
SWithoutBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType
    projSpecWithBias ::
      forall inputDim outputDim.
      SDim inputDim ->
      SDim outputDim ->
      ModelSpec
        ( GLinear
            (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])))
            (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim])))
        )
    projSpecWithBias :: forall (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinear
        (NamedModel
           (Tensor
              gradient
              ('Layout 'Dense)
              device
              dataType
              ('Shape '[outputDim, inputDim])))
        (NamedModel
           (Tensor
              gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (inputDim :: Dim (Name Symbol) (Size Nat))
       (outputDim :: Dim (Name Symbol) (Size Nat)).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputDim
-> SDim outputDim
-> ModelSpec
     (GLinearF hasBias gradient device dataType inputDim outputDim)
linearSpec SHasBias 'WithBias
SWithBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType

-- | 'HasInitialize' instance for 'GMultiHeadAttention'.
--
-- Initialization of the attention layer is expressed purely through
-- constraints on its five components. The generator device is threaded
-- through them in the order query projection, key projection, value
-- projection, output projection, dropout: the output generator device of one
-- component (@generatorDevice0@ .. @generatorDevice3@) is the input generator
-- device of the next, and the device produced by the dropout component is the
-- @generatorOutputDevice@ of the whole layer.
--
-- The primed type variables (@qInProj'@, @kInProj'@, ...) are the initialized
-- counterparts of the corresponding component specifications.
--
-- No methods are defined here, so the defaults of the 'HasInitialize' class
-- are used.
-- NOTE(review): presumably these are the Generic-based defaults declared in
-- "Torch.GraduallyTyped.NN.Class" -- confirm there.
instance
  ( HasInitialize qInProj generatorDevice qInProj' generatorDevice0,
    HasInitialize kInProj generatorDevice0 kInProj' generatorDevice1,
    HasInitialize vInProj generatorDevice1 vInProj' generatorDevice2,
    HasInitialize outProj generatorDevice2 outProj' generatorDevice3,
    HasInitialize dropout generatorDevice3 dropout' generatorOutputDevice
  ) =>
  HasInitialize
    (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout)
    generatorDevice
    (GMultiHeadAttention headDim headEmbedDim embedDim qInProj' kInProj' vInProj' outProj' dropout')
    generatorOutputDevice

-- | 'HasStateDict' instance for 'GMultiHeadAttention'.
--
-- The layer supports state-dictionary handling whenever each of its five
-- components (query, key, value and output projections, and dropout) does.
--
-- No methods are defined here, so the defaults of the 'HasStateDict' class
-- are used.
-- NOTE(review): presumably these are the Generic-based defaults declared in
-- "Torch.GraduallyTyped.NN.Class" -- confirm there.
instance
  ( HasStateDict qInProj,
    HasStateDict vInProj,
    HasStateDict kInProj,
    HasStateDict outProj,
    HasStateDict dropout
  ) =>
  HasStateDict (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout)

type BatchDim ::
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Dim (Name Symbol) (Size Nat)

-- | Type-level batch dimension of an attention input triple.
--
-- It is the unification (@\<+\>@) of the dimensions at index 0 of the query,
-- key, and value shapes. 'getBatchDim' computes the matching singleton at the
-- value level.
type BatchDim queryShape keyShape valueShape =
  (queryShape ! 0) <+> (keyShape ! 0) <+> (valueShape ! 0)

-- | Compute the singleton for the batch dimension shared by the query, key,
-- and value tensors.
--
-- The dimension at index 0 is read from each of the three shapes. The key and
-- value batch dimensions are unified first, and the result is then unified
-- with the query batch dimension, so the outcome has type
-- @'BatchDim' queryShape keyShape valueShape@.
--
-- Any failure raised by 'sGetDimFromShape' or 'sUnifyDim' is propagated in
-- the 'MonadThrow' monad @m@.
getBatchDim ::
  forall m queryShape keyShape valueShape batchDim.
  (MonadThrow m, batchDim ~ BatchDim queryShape keyShape valueShape) =>
  -- | shape of the query tensor
  SShape queryShape ->
  -- | shape of the key tensor
  SShape keyShape ->
  -- | shape of the value tensor
  SShape valueShape ->
  -- | unified batch dimension
  m (SDim batchDim)
getBatchDim queryShape keyShape valueShape = do
  queryBatchDim <- sGetDimFromShape (SSelectDim $ SByIndex @0) queryShape
  keyBatchDim <- sGetDimFromShape (SSelectDim $ SByIndex @0) keyShape
  valueBatchDim <- sGetDimFromShape (SSelectDim $ SByIndex @0) valueShape
  -- Unify key and value first so that the nesting of the result matches the
  -- 'BatchDim' synonym: query <+> (key <+> value).
  keyValueBatchDim <- sUnifyDim keyBatchDim valueBatchDim
  sUnifyDim queryBatchDim keyValueBatchDim

type QuerySeqDim ::
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Dim (Name Symbol) (Size Nat)

-- | Type-level query sequence dimension: the dimension at index 1 of the
-- query shape. 'getQuerySeqDim' computes the matching singleton at the value
-- level.
type QuerySeqDim queryShape =
  queryShape ! 1

-- | Compute the singleton for the sequence dimension of the query tensor,
-- i.e. the dimension at index 1 of the given query shape.
--
-- Any failure raised by 'sGetDimFromShape' is propagated in the 'MonadThrow'
-- monad @m@.
getQuerySeqDim ::
  forall m queryShape querySeqDim.
  (MonadThrow m, querySeqDim ~ QuerySeqDim queryShape) =>
  -- | shape of the query tensor
  SShape queryShape ->
  -- | query sequence dimension
  m (SDim querySeqDim)
getQuerySeqDim = sGetDimFromShape (SSelectDim $ SByIndex @1)

type KeySeqDim ::
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Dim (Name Symbol) (Size Nat)

-- | Type-level key sequence dimension: the unification (@\<+\>@) of the
-- dimensions at index 1 of the key and value shapes. 'getKeySeqDim' computes
-- the matching singleton at the value level.
type KeySeqDim keyShape valueShape =
  (keyShape ! 1) <+> (valueShape ! 1)

-- | Compute the singleton for the sequence dimension shared by the key and
-- value tensors.
--
-- The dimension at index 1 is read from both shapes and the two are unified,
-- so the outcome has type @'KeySeqDim' keyShape valueShape@.
--
-- Any failure raised by 'sGetDimFromShape' or 'sUnifyDim' is propagated in
-- the 'MonadThrow' monad @m@.
getKeySeqDim ::
  forall m keyShape valueShape keySeqDim.
  (MonadThrow m, keySeqDim ~ KeySeqDim keyShape valueShape) =>
  -- | shape of the key tensor
  SShape keyShape ->
  -- | shape of the value tensor
  SShape valueShape ->
  -- | unified key/value sequence dimension
  m (SDim keySeqDim)
getKeySeqDim keyShape valueShape =
  do
    keySeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) keyShape
    valueSeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) valueShape
    sUnifyDim keySeqDim valueSeqDim

-- | 'HasForward' instance for 'GMultiHeadAttention'.
--
-- @
-- ┌───────────────┐        ┌───────┐       ┌─────┐       ┌───────┐
-- │ attentionBias │        │ query │       │ key │       │ value │
-- └───────┬───────┘        └───┬───┘       └──┬──┘       └───┬───┘
--         │                    │              │              │
--         │                    ▼              ▼              ▼
--         │                mhaQInProj     mhaKInProj     mhaVInProj
--         │                    ▼              │              │
--         │                (scaling)          │              │
--         │                    ▼              ▼              ▼
--         │                 reshape        reshape        reshape
--         │                    ▼              ▼              ▼
--         │                transpose      transpose      transpose
--         │                    │              ▼              │
--         │                    │          transpose          │
--         │                    │              │              │
--         │                    └───►matmul◄───┘              │
--         │                           ▼                      │
--         │                       (scaling)                  │
--         │                           │                      │
--         └──────────►add◄────────────┘                      │
--                      ▼                                     │
--                   softmax                                  │
--                      ▼                                     │
--                  mhaDropout                                │
--                      │                                     │
--                      └──────────────►matmul◄───────────────┘
--                                        ▼
--                                    transpose
--                                        ▼
--                                     reshape
--                                        ▼
--                                    mhaOutProj
--                                        │
--                                        ▼
--                                    ┌───────┐
--                                    │ query │
--                                    └───────┘
-- @
instance
  ( HasForward
      qInProj
      (Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape)
      generatorDevice
      (Tensor qRequiresGradient qLayout qDevice qDataType qShape0)
      qGeneratorOutputDevice,
    reshapedQShape0 ~ ReshapeF qShape0 ('Shape '[batchDim, querySeqDim, headDim, headEmbedDim]),
    Catch reshapedQShape0,
    qShape ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) reshapedQShape0,
    Catch qShape,
    HasForward
      kInProj
      (Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape)
      qGeneratorOutputDevice
      (Tensor qRequiresGradient kLayout kDevice kDataType kShape0)
      kGeneratorOutputDevice,
    reshapedKShape0 ~ ReshapeF kShape0 ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]),
    Catch reshapedKShape0,
    transposedReshapedKShape0 ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) reshapedKShape0,
    Catch transposedReshapedKShape0,
    doubleTransposedReshapedKShape0 ~ TransposeF ('SelectDim ('ByIndex 2)) ('SelectDim ('ByIndex 3)) transposedReshapedKShape0,
    Catch doubleTransposedReshapedKShape0,
    multipliedQDoubleTransposedReshapedKShape0 ~ MatmulF qShape doubleTransposedReshapedKShape0,
    Catch multipliedQDoubleTransposedReshapedKShape0,
    weightsShape0
      ~ SoftmaxF
          ('SelectDim ('ByIndex 3))
          (BroadcastShapesF multipliedQDoubleTransposedReshapedKShape0 attentionBiasShape),
    Catch (BroadcastShapesF multipliedQDoubleTransposedReshapedKShape0 attentionBiasShape),
    Catch weightsShape0,
    HasForward
      dropout
      ( Tensor
          (qRequiresGradient <|> attentionBiasRequiresGradient)
          (qLayout <+> kLayout <+> attentionBiasLayout)
          (qDevice <+> kDevice <+> attentionBiasDevice)
          (qDataType <+> kDataType <+> attentionBiasDataType)
          weightsShape0
      )
      kGeneratorOutputDevice
      (Tensor weightsRequiresGradient weightsLayout weightsDevice weightsDataType weightsShape)
      weightsGeneratorOutputDevice,
    HasForward
      vInProj
      (Tensor valueRequiresGradient valueLayout valueDevice valueDataType valueShape)
      weightsGeneratorOutputDevice
      (Tensor weightsRequiresGradient vLayout vDevice vDataType vShape0)
      vGeneratorOutputDevice,
    reshapedVShape0 ~ ReshapeF vShape0 ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]),
    Catch reshapedVShape0,
    transposedReshapedVShape ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) reshapedVShape0,
    Catch transposedReshapedVShape,
    multipliedWeightsTransposedReshapedVShape ~ MatmulF weightsShape transposedReshapedVShape,
    Catch multipliedWeightsTransposedReshapedVShape,
    outputQueryShape0 ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) multipliedWeightsTransposedReshapedVShape,
    Catch outputQueryShape0,
    HasForward
      outProj
      ( Tensor
          weightsRequiresGradient
          (weightsLayout <+> vLayout)
          (weightsDevice <+> vDevice)
          (weightsDataType <+> vDataType)
          reshapedOutputQueryShape0
      )
      vGeneratorOutputDevice
      output
      generatorOutputDevice,
    reshapedOutputQueryShape0 ~ ReshapeF outputQueryShape0 ('Shape '[batchDim, querySeqDim, embedDim]),
    Catch reshapedOutputQueryShape0,
    SGetShape queryShape,
    SGetShape keyShape,
    SGetShape valueShape,
    batchDim ~ BatchDim queryShape keyShape valueShape,
    querySeqDim ~ QuerySeqDim queryShape,
    keySeqDim ~ KeySeqDim keyShape valueShape
  ) =>
  HasForward
    (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout)
    ( Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape,
      Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape,
      Tensor valueRequiresGradient valueLayout valueDevice valueDataType valueShape,
      Tensor attentionBiasRequiresGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape
    )
    generatorDevice
    output
    generatorOutputDevice
  where
  forward :: forall (m :: * -> *).
MonadThrow m =>
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> (Tensor
      queryRequiresGradient
      queryLayout
      queryDevice
      queryDataType
      queryShape,
    Tensor
      keyRequiresGradient keyLayout keyDevice keyDataType keyShape,
    Tensor
      valueRequiresGradient
      valueLayout
      valueDevice
      valueDataType
      valueShape,
    Tensor
      attentionBiasRequiresGradient
      attentionBiasLayout
      attentionBiasDevice
      attentionBiasDataType
      attentionBiasShape)
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GMultiHeadAttention {qInProj
kInProj
dropout
vInProj
outProj
SDim headDim
SDim headEmbedDim
SDim embedDim
MultiHeadAttentionHasScaling
mhaScaling :: MultiHeadAttentionHasScaling
mhaDropout :: dropout
mhaOutProj :: outProj
mhaVInProj :: vInProj
mhaKInProj :: kInProj
mhaQInProj :: qInProj
mhaEmbedDim :: SDim embedDim
mhaHeadEmbedDim :: SDim headEmbedDim
mhaHeadDim :: SDim headDim
mhaScaling :: forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> MultiHeadAttentionHasScaling
mhaDropout :: forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> dropout
mhaOutProj :: forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> outProj
mhaVInProj :: forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> vInProj
mhaKInProj :: forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> kInProj
mhaQInProj :: forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> qInProj
mhaEmbedDim :: forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> SDim embedDim
mhaHeadEmbedDim :: forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> SDim headEmbedDim
mhaHeadDim :: forall (headDim :: Dim (Name Symbol) (Size Nat))
       (headEmbedDim :: Dim (Name Symbol) (Size Nat))
       (embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
       outProj dropout.
GMultiHeadAttention
  headDim
  headEmbedDim
  embedDim
  qInProj
  kInProj
  vInProj
  outProj
  dropout
-> SDim headDim
..} (Tensor
  queryRequiresGradient
  queryLayout
  queryDevice
  queryDataType
  queryShape
query, Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape
key, Tensor
  valueRequiresGradient
  valueLayout
  valueDevice
  valueDataType
  valueShape
value, Tensor
  attentionBiasRequiresGradient
  attentionBiasLayout
  attentionBiasDevice
  attentionBiasDataType
  attentionBiasShape
attentionBias) Generator generatorDevice
g = do
    SDim batchDim
batchDim <-
      let queryShape :: SShape queryShape
queryShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor
  queryRequiresGradient
  queryLayout
  queryDevice
  queryDataType
  queryShape
query
          keyShape :: SShape keyShape
keyShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape
key
          valueShape :: SShape valueShape
valueShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor
  valueRequiresGradient
  valueLayout
  valueDevice
  valueDataType
  valueShape
value
       in forall (m :: * -> *)
       (queryShape :: Shape [Dim (Name Symbol) (Size Nat)])
       (keyShape :: Shape [Dim (Name Symbol) (Size Nat)])
       (valueShape :: Shape [Dim (Name Symbol) (Size Nat)])
       (batchDim :: Dim (Name Symbol) (Size Nat)).
(MonadThrow m,
 batchDim ~ BatchDim queryShape keyShape valueShape) =>
SShape queryShape
-> SShape keyShape -> SShape valueShape -> m (SDim batchDim)
getBatchDim SShape queryShape
queryShape SShape keyShape
keyShape SShape valueShape
valueShape
    SDim querySeqDim
querySeqDim <-
      let queryShape :: SShape queryShape
queryShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor
  queryRequiresGradient
  queryLayout
  queryDevice
  queryDataType
  queryShape
query
       in forall (m :: * -> *)
       (queryShape :: Shape [Dim (Name Symbol) (Size Nat)])
       (querySeqDim :: Dim (Name Symbol) (Size Nat)).
(MonadThrow m, querySeqDim ~ QuerySeqDim queryShape) =>
SShape queryShape -> m (SDim querySeqDim)
getQuerySeqDim SShape queryShape
queryShape
    SDim keySeqDim
keySeqDim <-
      let keyShape :: SShape keyShape
keyShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape
key
          valueShape :: SShape valueShape
valueShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor
  valueRequiresGradient
  valueLayout
  valueDevice
  valueDataType
  valueShape
value
       in forall (m :: * -> *)
       (keyShape :: Shape [Dim (Name Symbol) (Size Nat)])
       (valueShape :: Shape [Dim (Name Symbol) (Size Nat)])
       (keySeqDim :: Dim (Name Symbol) (Size Nat)).
(MonadThrow m, keySeqDim ~ KeySeqDim keyShape valueShape) =>
SShape keyShape -> SShape valueShape -> m (SDim keySeqDim)
getKeySeqDim SShape keyShape
keyShape SShape valueShape
valueShape
    let scaling :: Double
scaling = (Double
1 :: Double) forall a. Fractional a => a -> a -> a
/ (forall a. Floating a => a -> a
sqrt forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (Integral a, Num b) => a -> b
fromIntegral forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. IsChecked a -> a
forgetIsChecked forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall name size. Dim name size -> size
dimSize forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall k (a :: k). SingKind k => Sing a -> Demote k
fromSing forall a b. (a -> b) -> a -> b
$ SDim headEmbedDim
mhaHeadEmbedDim)
    forall a b c. (a -> b -> c) -> b -> a -> c
flip forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT Generator generatorDevice
g forall a b. (a -> b) -> a -> b
$
      let q :: IxStateT
  m
  (Generator generatorDevice)
  (Generator qGeneratorOutputDevice)
  (Tensor
     qRequiresGradient
     qLayout
     qDevice
     qDataType
     (TransposeF
        ('SelectDim ('ByIndex 1))
        ('SelectDim ('ByIndex 2))
        (ReshapeImplF
           (NumelF qShape0)
           (LiftTimesMaybe
              (NumelDimF batchDim)
              (LiftTimesMaybe
                 (NumelDimF querySeqDim)
                 (LiftTimesMaybe
                    (NumelDimF headDim)
                    (LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
           qShape0
           ('Shape '[batchDim, querySeqDim, headDim, headEmbedDim]))))
q =
            forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor
  queryRequiresGradient
  queryLayout
  queryDevice
  queryDataType
  queryShape
query
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
       output (generatorOutputDevice :: Device (DeviceType Nat))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward qInProj
mhaQInProj
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift
                forall b c a. (b -> c) -> (a -> b) -> a -> c
. ( \case
                      MultiHeadAttentionHasScaling
MultiHeadAttentionWithoutScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
                      MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling -> forall a b c. (a -> b -> c) -> b -> a -> c
flip forall other (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(Scalar other, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> other -> m (Tensor gradient layout device dataType shape)
mulScalar Double
scaling
                      MultiHeadAttentionHasScaling
MultiHeadAttentionWithWeightScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
                  )
                  MultiHeadAttentionHasScaling
mhaScaling
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *)
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ ReshapeF shape shape', Catch shape'') =>
SShape shape'
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape'')
sReshape (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim batchDim
batchDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim querySeqDim
querySeqDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headDim
mhaHeadDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headEmbedDim
mhaHeadEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil)
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
       (selectDim1 :: SelectDim (By Symbol Nat))
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
 MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @1)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2))
          k :: IxStateT
  m
  (Generator qGeneratorOutputDevice)
  (Generator kGeneratorOutputDevice)
  (Tensor
     qRequiresGradient
     kLayout
     kDevice
     kDataType
     (TransposeF
        ('SelectDim ('ByIndex 1))
        ('SelectDim ('ByIndex 2))
        (ReshapeImplF
           (NumelF kShape0)
           (LiftTimesMaybe
              (NumelDimF batchDim)
              (LiftTimesMaybe
                 (NumelDimF keySeqDim)
                 (LiftTimesMaybe
                    (NumelDimF headDim)
                    (LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
           kShape0
           ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]))))
k =
            forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape
key
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
       output (generatorOutputDevice :: Device (DeviceType Nat))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward kInProj
mhaKInProj
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *)
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ ReshapeF shape shape', Catch shape'') =>
SShape shape'
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape'')
sReshape (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim batchDim
batchDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim keySeqDim
keySeqDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headDim
mhaHeadDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headEmbedDim
mhaHeadEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil)
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
       (selectDim1 :: SelectDim (By Symbol Nat))
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
 MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @1)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2))
          kt :: IxStateT
  m
  (Generator qGeneratorOutputDevice)
  (Generator kGeneratorOutputDevice)
  (Tensor
     qRequiresGradient
     kLayout
     kDevice
     kDataType
     (TransposeF
        ('SelectDim ('ByIndex 2))
        ('SelectDim ('ByIndex 3))
        (TransposeF
           ('SelectDim ('ByIndex 1))
           ('SelectDim ('ByIndex 2))
           (ReshapeImplF
              (NumelF kShape0)
              (LiftTimesMaybe
                 (NumelDimF batchDim)
                 (LiftTimesMaybe
                    (NumelDimF keySeqDim)
                    (LiftTimesMaybe
                       (NumelDimF headDim)
                       (LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
              kShape0
              ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim])))))
kt = IxStateT
  m
  (Generator qGeneratorOutputDevice)
  (Generator kGeneratorOutputDevice)
  (Tensor
     qRequiresGradient
     kLayout
     kDevice
     kDataType
     (TransposeF
        ('SelectDim ('ByIndex 1))
        ('SelectDim ('ByIndex 2))
        (ReshapeImplF
           (NumelF kShape0)
           (LiftTimesMaybe
              (NumelDimF batchDim)
              (LiftTimesMaybe
                 (NumelDimF keySeqDim)
                 (LiftTimesMaybe
                    (NumelDimF headDim)
                    (LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
           kShape0
           ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]))))
k forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
       (selectDim1 :: SelectDim (By Symbol Nat))
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
 MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @3))
          weights :: IxStateT
  m
  (Generator generatorDevice)
  (Generator weightsGeneratorOutputDevice)
  (Tensor
     weightsRequiresGradient
     weightsLayout
     weightsDevice
     weightsDataType
     weightsShape)
weights =
            (,) forall {k1} {k2} (f :: k1 -> k2 -> * -> *) a b (j :: k1)
       (k3 :: k2).
IxFunctor f =>
(a -> b) -> f j k3 a -> f j k3 b
<<$>> IxStateT
  m
  (Generator generatorDevice)
  (Generator qGeneratorOutputDevice)
  (Tensor
     qRequiresGradient
     qLayout
     qDevice
     qDataType
     (TransposeF
        ('SelectDim ('ByIndex 1))
        ('SelectDim ('ByIndex 2))
        (ReshapeImplF
           (NumelF qShape0)
           (LiftTimesMaybe
              (NumelDimF batchDim)
              (LiftTimesMaybe
                 (NumelDimF querySeqDim)
                 (LiftTimesMaybe
                    (NumelDimF headDim)
                    (LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
           qShape0
           ('Shape '[batchDim, querySeqDim, headDim, headEmbedDim]))))
q forall {k1} (f :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a b
       (k2 :: k1).
IxApplicative f =>
f i j (a -> b) -> f j k2 a -> f i k2 b
<<*>> IxStateT
  m
  (Generator qGeneratorOutputDevice)
  (Generator kGeneratorOutputDevice)
  (Tensor
     qRequiresGradient
     kLayout
     kDevice
     kDataType
     (TransposeF
        ('SelectDim ('ByIndex 2))
        ('SelectDim ('ByIndex 3))
        (TransposeF
           ('SelectDim ('ByIndex 1))
           ('SelectDim ('ByIndex 2))
           (ReshapeImplF
              (NumelF kShape0)
              (LiftTimesMaybe
                 (NumelDimF batchDim)
                 (LiftTimesMaybe
                    (NumelDimF keySeqDim)
                    (LiftTimesMaybe
                       (NumelDimF headDim)
                       (LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
              kShape0
              ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim])))))
kt
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry forall (m :: * -> *) (gradient :: Gradient RequiresGradient)
       (gradient' :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (layout' :: Layout LayoutType)
       (device :: Device (DeviceType Nat))
       (device' :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (dataType' :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ MatmulF shape shape', Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
        (gradient <|> gradient')
        (layout <+> layout')
        (device <+> device')
        (dataType <+> dataType')
        shape'')
matmul
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift
                forall b c a. (b -> c) -> (a -> b) -> a -> c
. ( \case
                      MultiHeadAttentionHasScaling
MultiHeadAttentionWithoutScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
                      MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
                      MultiHeadAttentionHasScaling
MultiHeadAttentionWithWeightScaling -> forall a b c. (a -> b -> c) -> b -> a -> c
flip forall other (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(Scalar other, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> other -> m (Tensor gradient layout device dataType shape)
mulScalar Double
scaling
                  )
                  MultiHeadAttentionHasScaling
mhaScaling
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient' :: Gradient RequiresGradient)
       (layout' :: Layout LayoutType) (device' :: Device (DeviceType Nat))
       (dataType' :: DataType DType)
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(MonadThrow m, shape'' ~ BroadcastShapesF shape shape',
 Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
        (gradient <|> gradient')
        (layout <+> layout')
        (device <+> device')
        (dataType <+> dataType')
        shape'')
`add` Tensor
  attentionBiasRequiresGradient
  attentionBiasLayout
  attentionBiasDevice
  attentionBiasDataType
  attentionBiasShape
attentionBias)
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim :: SelectDim (By Symbol Nat))
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(MonadThrow m, shape' ~ SoftmaxF selectDim shape, Catch shape') =>
SSelectDim selectDim
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
softmax (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @3))
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
       output (generatorOutputDevice :: Device (DeviceType Nat))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward dropout
mhaDropout
          v :: IxStateT
  m
  (Generator weightsGeneratorOutputDevice)
  (Generator vGeneratorOutputDevice)
  (Tensor
     weightsRequiresGradient
     vLayout
     vDevice
     vDataType
     (TransposeF
        ('SelectDim ('ByIndex 1))
        ('SelectDim ('ByIndex 2))
        (ReshapeImplF
           (NumelF vShape0)
           (LiftTimesMaybe
              (NumelDimF batchDim)
              (LiftTimesMaybe
                 (NumelDimF keySeqDim)
                 (LiftTimesMaybe
                    (NumelDimF headDim)
                    (LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
           vShape0
           ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]))))
v =
            forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor
  valueRequiresGradient
  valueLayout
  valueDevice
  valueDataType
  valueShape
value
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
       output (generatorOutputDevice :: Device (DeviceType Nat))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward vInProj
mhaVInProj
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *)
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ ReshapeF shape shape', Catch shape'') =>
SShape shape'
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape'')
sReshape (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim batchDim
batchDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim keySeqDim
keySeqDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headDim
mhaHeadDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headEmbedDim
mhaHeadEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil)
              forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
       (selectDim1 :: SelectDim (By Symbol Nat))
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
 MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @1)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2))
       in (,) forall {k1} {k2} (f :: k1 -> k2 -> * -> *) a b (j :: k1)
       (k3 :: k2).
IxFunctor f =>
(a -> b) -> f j k3 a -> f j k3 b
<<$>> IxStateT
  m
  (Generator generatorDevice)
  (Generator weightsGeneratorOutputDevice)
  (Tensor
     weightsRequiresGradient
     weightsLayout
     weightsDevice
     weightsDataType
     weightsShape)
weights forall {k1} (f :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a b
       (k2 :: k1).
IxApplicative f =>
f i j (a -> b) -> f j k2 a -> f i k2 b
<<*>> IxStateT
  m
  (Generator weightsGeneratorOutputDevice)
  (Generator vGeneratorOutputDevice)
  (Tensor
     weightsRequiresGradient
     vLayout
     vDevice
     vDataType
     (TransposeF
        ('SelectDim ('ByIndex 1))
        ('SelectDim ('ByIndex 2))
        (ReshapeImplF
           (NumelF vShape0)
           (LiftTimesMaybe
              (NumelDimF batchDim)
              (LiftTimesMaybe
                 (NumelDimF keySeqDim)
                 (LiftTimesMaybe
                    (NumelDimF headDim)
                    (LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
           vShape0
           ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]))))
v
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry forall (m :: * -> *) (gradient :: Gradient RequiresGradient)
       (gradient' :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (layout' :: Layout LayoutType)
       (device :: Device (DeviceType Nat))
       (device' :: Device (DeviceType Nat)) (dataType :: DataType DType)
       (dataType' :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ MatmulF shape shape', Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
        (gradient <|> gradient')
        (layout <+> layout')
        (device <+> device')
        (dataType <+> dataType')
        shape'')
matmul
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
       (selectDim1 :: SelectDim (By Symbol Nat))
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
 MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @1)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2))
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
       (i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *)
       (shape' :: Shape [Dim (Name Symbol) (Size Nat)])
       (gradient :: Gradient RequiresGradient)
       (layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
       (dataType :: DataType DType)
       (shape :: Shape [Dim (Name Symbol) (Size Nat)])
       (shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ ReshapeF shape shape', Catch shape'') =>
SShape shape'
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape'')
sReshape (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim batchDim
batchDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim querySeqDim
querySeqDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim embedDim
mhaEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil)
            forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
       (k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
       output (generatorOutputDevice :: Device (DeviceType Nat))
       (m :: * -> *).
(HasForward
   model input generatorDevice output generatorOutputDevice,
 MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward outProj
mhaOutProj