{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Transformer.GCrossAttention where

import Control.Monad.Indexed (IxPointed (ireturn), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType, SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Dropout (Dropout (..))
import Torch.GraduallyTyped.NN.Normalization (LayerNorm (..), LayerNormSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention (GMultiHeadAttentionF, multiHeadAttentionSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), HasDropout (..), SHasBias (..), SHasDropout (..))
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, SShape (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

-- | Generic cross-attention layer data type.
--
-- - @initialLayerNorm@: the initial layer normalization
-- - @mha@: the multi-headed attention layer
-- - @dropout@: the dropout layer
-- - @finalLayerNorm@: the final layer normalization
--
-- Every component is a type parameter, so the same record describes both an
-- instantiated layer and its specification (via 'ModelSpec'). Components that
-- a given transformer style does not use are instantiated to @()@.
data
  GCrossAttention
    (initialLayerNorm :: Type)
    (mha :: Type)
    (dropout :: Type)
    (finalLayerNorm :: Type)
  where
  GCrossAttention ::
    forall initialLayerNorm mha dropout finalLayerNorm.
    { -- | initial layer normalization of the cross-attention layer.
      caInitialLayerNorm :: initialLayerNorm,
      -- | multi-headed attention layer specialized for cross-attention.
      caMultiHeadAttention :: mha,
      -- | dropout applied to the output of the multi-headed attention layer.
      caDropout :: dropout,
      -- | final layer normalization of the cross-attention layer.
      caFinalLayerNorm :: finalLayerNorm
    } ->
    GCrossAttention initialLayerNorm mha dropout finalLayerNorm
  deriving stock (Eq, Ord, Show, Generic)

-- | The specification of a 'GCrossAttention' is again a 'GCrossAttention',
-- with each of its four components replaced by that component's 'ModelSpec'.
type instance
  ModelSpec (GCrossAttention preNorm attention drop postNorm) =
    GCrossAttention
      (ModelSpec preNorm)
      (ModelSpec attention)
      (ModelSpec drop)
      (ModelSpec postNorm)

-- | Computes the concrete 'GCrossAttention' type for a transformer @style@
-- and a set of hyper-parameters.
--
-- Each of the four components is selected by its own style-indexed family:
-- 'CAInitialLayerNormF', 'CAMultiheadAttentionF', 'CADropoutF' and
-- 'CAFinalLayerNormF'.
type family
  GCrossAttentionF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (keyEmbedDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  GCrossAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout =
    GCrossAttention
      (CAInitialLayerNormF style gradient device dataType queryEmbedDim)
      (CAMultiheadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout)
      (CADropoutF style hasDropout)
      (CAFinalLayerNormF style gradient device dataType queryEmbedDim)

-- | Specifies the initial layer normalization of the cross-attention layer.
--
-- The layer norm normalizes over the single dimension @queryEmbedDim@.
-- 'T5' and 'ByT5' use a layer norm without bias, 'Pegasus' uses one with
-- bias, and 'BART' / 'MBART' have no initial layer norm (@()@).
-- There are no equations for the other styles (e.g. 'BERT', 'RoBERTa',
-- 'GPT2'), so the family does not reduce for them.
type family
  CAInitialLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  -- T5: bias-free layer norm over the query embedding dimension.
  CAInitialLayerNormF 'T5 gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[queryEmbedDim]))
  -- ByT5 shares T5's choice.
  CAInitialLayerNormF 'ByT5 gradient device dataType queryEmbedDim =
    CAInitialLayerNormF 'T5 gradient device dataType queryEmbedDim
  -- BART has no initial layer norm.
  CAInitialLayerNormF 'BART _ _ _ _ =
    ()
  -- MBART shares BART's choice.
  CAInitialLayerNormF 'MBART gradient device dataType queryEmbedDim =
    CAInitialLayerNormF 'BART gradient device dataType queryEmbedDim
  -- Pegasus: layer norm with bias.
  CAInitialLayerNormF 'Pegasus gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))

-- | Specifies the multi-headed attention layer specialized for cross-attention.
--
-- The result is the generic 'GMultiHeadAttentionF' layer wrapped in a
-- 'NamedModel'. @keyEmbedDim@ is passed twice. Presumably this is because in
-- cross-attention both keys and values come from the same source; the
-- corresponding parameter of 'multiHeadAttentionSpec' is the value
-- embedding dimension (TODO confirm against GMultiHeadAttention).
type family
  CAMultiheadAttentionF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (keyEmbedDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  CAMultiheadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout =
    NamedModel (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim keyEmbedDim hasDropout)

-- | Specifies the dropout layer of the cross-attention layer.
--
-- The choice depends only on @hasDropout@: a 'Dropout' layer when dropout is
-- enabled and @()@ otherwise. The @style@ argument is accepted for uniformity
-- with the other component families but is ignored.
type family
  CADropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  CADropoutF _ 'WithDropout = Dropout
  CADropoutF _ 'WithoutDropout = ()

-- | Specifies the final layer normalization of the cross-attention layer.
--
-- Only the 'BART' and 'MBART' styles have a final layer normalization (a
-- layer norm with bias over @queryEmbedDim@). For 'T5', 'ByT5' and 'Pegasus'
-- the component is absent and is represented by @()@. Styles without an
-- equation here (e.g. 'BERT') do not reduce.
type family
  CAFinalLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  CAFinalLayerNormF 'T5 _ _ _ _ =
    ()
  CAFinalLayerNormF 'ByT5 gradient device dataType queryEmbedDim =
    CAFinalLayerNormF 'T5 gradient device dataType queryEmbedDim
  CAFinalLayerNormF 'BART gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  CAFinalLayerNormF 'MBART gradient device dataType queryEmbedDim =
    CAFinalLayerNormF 'BART gradient device dataType queryEmbedDim
  -- Pegasus normalizes before attention instead (see CAInitialLayerNormF).
  CAFinalLayerNormF 'Pegasus _ _ _ _ =
    ()

-- | Specifies the parameters of a cross-attention layer.
--
-- - @style@: the style of the transformer, e.g. 'ST5', 'SByT5', etc.
-- - @gradient@: whether to compute the gradient of the layer's parameters.
-- - @device@: the computational device on which the layer is allocated.
-- - @dataType@: the data type of the layer's parameters.
-- - @headDim@: the dimension of all transformer heads in the layer.
-- - @headEmbedDim@: the dimension of the transformer head embeddings.
-- - @embedDim@: the dimension of the transformer embeddings.
-- - @queryEmbedDim@: the dimension of the transformer query embeddings.
-- - @keyEmbedDim@: the dimension of the transformer key embeddings; it is
--   also passed as the value embedding dimension of the attention layer.
-- - @hasDropout@: whether the layer contains a dropout component.
-- - @dropoutP@: the dropout rate.
-- - @eps@: the epsilon value for numerical stability of the layer normalization.
--
-- The styles 'SBERT', 'SRoBERTa' and 'SGPT2' have no cross-attention layer.
-- For them, the components of the returned specification are 'error' thunks.
crossAttentionSpec ::
  forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout.
  STransformerStyle style ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SDim keyEmbedDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (GCrossAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout)
crossAttentionSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout dropoutP eps =
  let initialLayerNormSpec ::
        STransformerStyle style ->
        ModelSpec (CAInitialLayerNormF style gradient device dataType queryEmbedDim)
      initialLayerNormSpec ST5 = NamedModel "layer_norm." layerNormWithoutBiasSpec
      initialLayerNormSpec SByT5 = NamedModel "layer_norm." layerNormWithoutBiasSpec
      initialLayerNormSpec SBART = ()
      initialLayerNormSpec SMBART = ()
      initialLayerNormSpec SPegasus = NamedModel "encoder_attn_layer_norm." layerNormWithBiasSpec
      initialLayerNormSpec SBERT = unsupported "BERT"
      initialLayerNormSpec SRoBERTa = unsupported "RoBERTa"
      initialLayerNormSpec SGPT2 = unsupported "GPT2"
      -- The attention weights live under a style-specific state-dict prefix.
      mhaSpec ::
        STransformerStyle style ->
        ModelSpec (CAMultiheadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout)
      mhaSpec ST5 = NamedModel "EncDecAttention." $ mhaSpec' ST5
      mhaSpec SByT5 = NamedModel "EncDecAttention." $ mhaSpec' SByT5
      mhaSpec SBART = NamedModel "encoder_attn." $ mhaSpec' SBART
      mhaSpec SMBART = NamedModel "encoder_attn." $ mhaSpec' SMBART
      mhaSpec SPegasus = NamedModel "encoder_attn." $ mhaSpec' SPegasus
      mhaSpec SBERT = unsupported "BERT"
      mhaSpec SRoBERTa = unsupported "RoBERTa"
      mhaSpec SGPT2 = unsupported "GPT2"
      -- Dropout depends only on hasDropout; the style is ignored.
      dropoutSpec ::
        STransformerStyle style ->
        SHasDropout hasDropout ->
        ModelSpec (CADropoutF style hasDropout)
      dropoutSpec _ SWithDropout = Dropout dropoutP
      dropoutSpec _ SWithoutDropout = ()
      finalLayerNormSpec ::
        STransformerStyle style ->
        ModelSpec (CAFinalLayerNormF style gradient device dataType queryEmbedDim)
      finalLayerNormSpec ST5 = ()
      finalLayerNormSpec SByT5 = ()
      finalLayerNormSpec SBART = NamedModel "encoder_attn_layer_norm." layerNormWithBiasSpec
      finalLayerNormSpec SMBART = NamedModel "encoder_attn_layer_norm." layerNormWithBiasSpec
      finalLayerNormSpec SPegasus = ()
      finalLayerNormSpec SBERT = unsupported "BERT"
      finalLayerNormSpec SRoBERTa = unsupported "RoBERTa"
      finalLayerNormSpec SGPT2 = unsupported "GPT2"
   in GCrossAttention
        (initialLayerNormSpec style)
        (mhaSpec style)
        (dropoutSpec style hasDropout)
        (finalLayerNormSpec style)
  where
    -- Generic multi-head attention spec; keyEmbedDim is passed for both the
    -- key and the value embedding dimensions.
    mhaSpec' ::
      STransformerStyle style ->
      ModelSpec (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim keyEmbedDim hasDropout)
    mhaSpec' style' =
      multiHeadAttentionSpec style' gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim keyEmbedDim hasDropout dropoutP

    -- Layer norm over the query embedding dimension, without a bias term.
    layerNormWithoutBiasSpec :: LayerNormSpec 'WithoutBias gradient device dataType ('Shape '[queryEmbedDim])
    layerNormWithoutBiasSpec = LayerNormSpec SWithoutBias gradient device dataType (SShape $ queryEmbedDim :|: SNil) eps

    -- Layer norm over the query embedding dimension, with a bias term.
    layerNormWithBiasSpec :: LayerNormSpec 'WithBias gradient device dataType ('Shape '[queryEmbedDim])
    layerNormWithBiasSpec = LayerNormSpec SWithBias gradient device dataType (SShape $ queryEmbedDim :|: SNil) eps

    -- Placeholder for styles that have no cross-attention layer; unlike a bare
    -- 'undefined', it reports which style was requested if it is ever forced.
    unsupported :: String -> a
    unsupported styleName =
      error $ "crossAttentionSpec: there is no cross-attention layer for the " <> styleName <> " transformer style"

-- | Initialization of a 'GCrossAttention' from its specification.
--
-- The constraints pass the generator device from one component to the next,
-- in field order: initial layer norm, attention, dropout, final layer norm.
-- No methods are defined here, so the class's default implementations are
-- used (presumably Generic-based; TODO confirm in Torch.GraduallyTyped.NN.Class).
instance
  ( HasInitialize initialLayerNorm genDevice initialLayerNorm' genDevice0,
    HasInitialize mha genDevice0 mha' genDevice1,
    HasInitialize dropout genDevice1 dropout' genDevice2,
    HasInitialize finalLayerNorm genDevice2 finalLayerNorm' genOutputDevice
  ) =>
  HasInitialize
    (GCrossAttention initialLayerNorm mha dropout finalLayerNorm)
    genDevice
    (GCrossAttention initialLayerNorm' mha' dropout' finalLayerNorm')
    genOutputDevice

-- | State-dict support for 'GCrossAttention', available whenever every
-- component supports it. No methods are defined here, so the class's default
-- implementations are used (presumably Generic-based; TODO confirm).
instance
  ( HasStateDict initialLayerNorm,
    HasStateDict mha,
    HasStateDict dropout,
    HasStateDict finalLayerNorm
  ) =>
  HasStateDict
    (GCrossAttention initialLayerNorm mha dropout finalLayerNorm)

-- | 'HasForward' instance for 'GCrossAttention'.
--
-- @
--        ┌───────┐    ┌─────┐    ┌───────────────┐
--        │ query │    │ key │    │ attentionBias │
--        └───┬───┘    └──┬──┘    └───────┬───────┘
--            │           │               │
-- ┌──────────┤           │               │
-- │          │           │               │
-- │          ▼           │               │
-- │ (caInitialLayerNorm) │               │
-- │          │           │               │
-- │          │       ┌───┴───┐           │
-- │          │       │       │           │
-- │          ▼       ▼       ▼           │
-- │        caMultiHeadAttention◄─────────┘
-- │                  │
-- │                  ▼
-- │              caDropout
-- │                  │
-- └──────►add◄───────┘
--          │
--          ▼
--  (caFinalLayerNorm)
--          │
--          ▼
--      ┌───────┐
--      │ query │
--      └───────┘
-- @
instance
  ( HasForward
      initialLayerNorm
      (Tensor queryGradient queryLayout queryDevice queryDataType queryShape)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      multiHeadAttention
      ( tensor0,
        Tensor keyGradient keyLayout keyDevice keyDataType keyShape,
        Tensor keyGradient keyLayout keyDevice keyDataType keyShape,
        Tensor attentionBiasGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape
      )
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      dropout
      tensor1
      generatorDevice1
      (Tensor gradient2 layout2 device2 dataType2 shape2)
      generatorDevice2,
    HasForward
      finalLayerNorm
      (Tensor (queryGradient <|> gradient2) (queryLayout <+> layout2) (queryDevice <+> device2) (queryDataType <+> dataType2) (BroadcastShapesF queryShape shape2))
      generatorDevice2
      output
      generatorOutputDevice,
    Catch (BroadcastShapesF queryShape shape2)
  ) =>
  HasForward
    (GCrossAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm)
    ( Tensor queryGradient queryLayout queryDevice queryDataType queryShape,
      Tensor keyGradient keyLayout keyDevice keyDataType keyShape,
      Tensor attentionBiasGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape
    )
    generatorDevice
    output
    generatorOutputDevice
  where
    -- The random generator is threaded through every sub-layer with an
    -- indexed state monad ('IxStateT'). Each step may change the
    -- generator's device index (generatorDevice -> generatorDevice0 -> ...),
    -- which a plain 'StateT' could not express.
    --
    -- No method signature is given on purpose: one would require
    -- InstanceSigs plus imports of 'MonadThrow' and 'Generator'.
    forward GCrossAttention {..} (query, key, attentionBias) =
      runIxStateT $
        ireturn query
          -- Layer norm applied to the incoming query.
          >>>= IxStateT . forward caInitialLayerNorm
          -- The key tensor fills both the second and the third slot of the
          -- attention input (presumably keys and values). The bias is
          -- passed through unchanged.
          >>>= (\query' -> IxStateT $ forward caMultiHeadAttention (query', key, key, attentionBias))
          >>>= IxStateT . forward caDropout
          -- Residual connection to the ORIGINAL query, not the normalized
          -- one. 'add' does not use the generator, so it is lifted with
          -- 'ilift' and leaves the generator state untouched.
          >>>= ilift . (query `add`)
          -- Layer norm applied to the residual sum.
          >>>= IxStateT . forward caFinalLayerNorm