{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}

module Torch.GraduallyTyped.NN.Transformer.GBlock where

import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType, SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Transformer.GCrossAttention (GCrossAttentionF, crossAttentionSpec)
import Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork (GTransformerFeedForwardNetworkF, transformerFeedForwardNetworkSpec)
import Torch.GraduallyTyped.NN.Transformer.GSelfAttention (GSelfAttentionF, selfAttentionSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Size (..))

-- | Generic transformer block consisting of self-attention, cross-attention, and a feed-forward network.
--
-- The same type backs both encoder and decoder blocks, see 'EncoderBlockF' and 'DecoderBlockF'.
-- An encoder block has no cross-attention and fills that slot with @()@.
--
-- - @selfAttention@ is a self-attention layer.
-- - @crossAttention@ is a cross-attention layer.
-- - @feedForwardNetwork@ is a feed-forward layer.
--
-- TODO: Some transformers use LayerDrop, see https://arxiv.org/abs/1909.11556, during training.
-- To support this, we will need a layer wrapper that is either the identity function or the wrapped layer
-- based on a uniformly random draw from a supplied generator.
data
  GTransformerBlock
    (selfAttention :: Type)
    (crossAttention :: Type)
    (feedForwardNetwork :: Type)
  where
  GTransformerBlock ::
    forall selfAttention crossAttention feedForwardNetwork.
    { -- | self-attention layer
      tbSelfAttention :: selfAttention,
      -- | cross-attention layer
      tbCrossAttention :: crossAttention,
      -- | feed-forward network
      tbFeedForwardNetwork :: feedForwardNetwork
    } ->
    GTransformerBlock selfAttention crossAttention feedForwardNetwork
  deriving stock (Eq, Ord, Show, Generic)

-- | The specification of a 'GTransformerBlock' is again a 'GTransformerBlock',
-- with each of the three component types replaced by its own 'ModelSpec'.
type instance
  ModelSpec (GTransformerBlock sa ca ffn) =
    GTransformerBlock
      (ModelSpec sa)
      (ModelSpec ca)
      (ModelSpec ffn)

-- | Computes the type of an encoder block for the given style and dimensions.
--
-- The result is a 'GTransformerBlock' holding a named self-attention layer
-- ('GSelfAttentionF') and a named feed-forward network
-- ('GTransformerFeedForwardNetworkF'). The cross-attention slot is @()@.
type family
  EncoderBlockF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout)
  where
  EncoderBlockF sty grad dev dty hDim heDim eDim qDim fDim dropout =
    GTransformerBlock
      -- self-attention, wrapped so that it can carry a state-dict name prefix
      (NamedModel (GSelfAttentionF sty grad dev dty hDim heDim eDim qDim dropout))
      -- encoder blocks have no cross-attention
      ()
      -- feed-forward network, likewise wrapped in 'NamedModel'
      (NamedModel (GTransformerFeedForwardNetworkF sty grad dev dty qDim fDim dropout))

-- | Specifies the layers of an encoder block.
--
-- The singleton arguments, @dropoutP@, and @eps@ are passed on unchanged to
-- 'selfAttentionSpec' and 'transformerFeedForwardNetworkSpec'
-- (@dropoutP@ and @eps@ are presumably the dropout probability and the
-- layer-norm epsilon -- confirm against those functions).
--
-- The sub-layer specifications are wrapped in 'NamedModel' with a
-- style-dependent name prefix:
--
-- +------------------------+------------------+----------------------+
-- | style                  | self-attention   | feed-forward network |
-- +========================+==================+======================+
-- | T5, ByT5               | @"layer.0."@     | @"layer.1."@         |
-- +------------------------+------------------+----------------------+
-- | BART, MBART, Pegasus   | (empty)          | (empty)              |
-- +------------------------+------------------+----------------------+
-- | BERT, RoBERTa          | @"attention."@   | (empty)              |
-- +------------------------+------------------+----------------------+
--
-- The cross-attention slot of an encoder block is always @()@.
--
-- For the GPT2 style, the self-attention and feed-forward specifications are
-- not defined: forcing either field of the result raises an error.
encoderBlockSpec ::
  forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout.
  STransformerStyle style ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SDim ffnDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (EncoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout)
encoderBlockSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout dropoutP eps =
  -- Each style is matched explicitly so that the style-indexed type families
  -- in the result type can reduce in every branch.
  let saSpec ST5 = NamedModel "layer.0." $ saSpec' ST5
      saSpec SByT5 = NamedModel "layer.0." $ saSpec' SByT5
      saSpec SBART = NamedModel mempty $ saSpec' SBART
      saSpec SMBART = NamedModel mempty $ saSpec' SMBART
      saSpec SPegasus = NamedModel mempty $ saSpec' SPegasus
      saSpec SBERT = NamedModel "attention." $ saSpec' SBERT
      saSpec SRoBERTa = NamedModel "attention." $ saSpec' SRoBERTa
      saSpec SGPT2 = error "encoderBlockSpec: the self-attention specification is not defined for the GPT2 style"
      -- encoder blocks have no cross-attention, whatever the style
      caSpec _ = ()
      ffnSpec ST5 = NamedModel "layer.1." $ ffnSpec' ST5
      ffnSpec SByT5 = NamedModel "layer.1." $ ffnSpec' SByT5
      ffnSpec SBART = NamedModel mempty $ ffnSpec' SBART
      ffnSpec SMBART = NamedModel mempty $ ffnSpec' SMBART
      ffnSpec SPegasus = NamedModel mempty $ ffnSpec' SPegasus
      ffnSpec SBERT = NamedModel mempty $ ffnSpec' SBERT
      ffnSpec SRoBERTa = NamedModel mempty $ ffnSpec' SRoBERTa
      ffnSpec SGPT2 = error "encoderBlockSpec: the feed-forward network specification is not defined for the GPT2 style"
   in GTransformerBlock (saSpec style) (caSpec style) (ffnSpec style)
  where
    -- Unnamed self-attention specification for a given style; the type is
    -- left to inference via a partial type signature.
    saSpec' :: _
    saSpec' style' = selfAttentionSpec style' gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout dropoutP eps
    -- Unnamed feed-forward network specification for a given style.
    ffnSpec' :: _
    ffnSpec' style' = transformerFeedForwardNetworkSpec style' gradient device dataType queryEmbedDim ffnDim hasDropout dropoutP eps

-- | Type of a transformer block in a decoder configuration.
--
-- The single equation expands to a 'GTransformerBlock' whose three slots are,
-- in order, the self-attention layer ('GSelfAttentionF'), the cross-attention
-- layer ('GCrossAttentionF'), and the feed-forward network
-- ('GTransformerFeedForwardNetworkF'), each wrapped in 'NamedModel'.
--
-- Only the cross-attention layer receives @keyEmbedDim@, and only the
-- feed-forward network receives @ffnDim@.
type family
  DecoderBlockF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (keyEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout)
  where
  DecoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout =
    GTransformerBlock
      ( NamedModel
          (GSelfAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout)
      )
      ( NamedModel
          (GCrossAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout)
      )
      ( NamedModel
          (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout)
      )

-- | Build the specification of a transformer block in a decoder configuration.
--
-- The result is a 'GTransformerBlock' holding three named component specs,
-- produced by 'selfAttentionSpec', 'crossAttentionSpec', and
-- 'transformerFeedForwardNetworkSpec' respectively.
--
-- The name given to each component depends on the transformer style:
--
-- - 'ST5' and 'SByT5': @"layer.0."@ (self-attention), @"layer.1."@
--   (cross-attention), and @"layer.2."@ (feed-forward network).
-- - 'SBART', 'SMBART', and 'SPegasus': the empty name ('mempty') for all three.
-- - 'SBERT', 'SRoBERTa', and 'SGPT2': no component spec is defined; forcing
--   any component of the returned block raises an error.
--
-- The two 'Double' arguments (@dropoutP@ and @eps@) are forwarded unchanged to
-- all three component spec builders.
-- NOTE(review): presumably the dropout probability and the layer-norm epsilon;
-- confirm against the component spec builders.
decoderBlockSpec ::
  forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout.
  STransformerStyle style ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SDim keyEmbedDim ->
  SDim ffnDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (DecoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout)
decoderBlockSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout dropoutP eps =
  let -- Self-attention component, named per style.
      saSpec ST5 = NamedModel "layer.0." $ saSpec' ST5
      saSpec SByT5 = NamedModel "layer.0." $ saSpec' SByT5
      saSpec SBART = NamedModel mempty $ saSpec' SBART
      saSpec SMBART = NamedModel mempty $ saSpec' SMBART
      saSpec SPegasus = NamedModel mempty $ saSpec' SPegasus
      saSpec SBERT = unsupported "self-attention" "BERT"
      saSpec SRoBERTa = unsupported "self-attention" "RoBERTa"
      saSpec SGPT2 = unsupported "self-attention" "GPT2"
      -- Cross-attention component, named per style.
      caSpec ST5 = NamedModel "layer.1." $ caSpec' ST5
      caSpec SByT5 = NamedModel "layer.1." $ caSpec' SByT5
      caSpec SBART = NamedModel mempty $ caSpec' SBART
      caSpec SMBART = NamedModel mempty $ caSpec' SMBART
      caSpec SPegasus = NamedModel mempty $ caSpec' SPegasus
      caSpec SBERT = unsupported "cross-attention" "BERT"
      caSpec SRoBERTa = unsupported "cross-attention" "RoBERTa"
      caSpec SGPT2 = unsupported "cross-attention" "GPT2"
      -- Feed-forward network component, named per style.
      ffnSpec ST5 = NamedModel "layer.2." $ ffnSpec' ST5
      ffnSpec SByT5 = NamedModel "layer.2." $ ffnSpec' SByT5
      ffnSpec SBART = NamedModel mempty $ ffnSpec' SBART
      ffnSpec SMBART = NamedModel mempty $ ffnSpec' SMBART
      ffnSpec SPegasus = NamedModel mempty $ ffnSpec' SPegasus
      ffnSpec SBERT = unsupported "feed-forward network" "BERT"
      ffnSpec SRoBERTa = unsupported "feed-forward network" "RoBERTa"
      ffnSpec SGPT2 = unsupported "feed-forward network" "GPT2"
   in GTransformerBlock (saSpec style) (caSpec style) (ffnSpec style)
  where
    -- The helpers below close over every argument except the style, so that
    -- each clause above can pass the style singleton it has just matched on.
    saSpec' :: _
    saSpec' style' = selfAttentionSpec style' gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout dropoutP eps
    caSpec' :: _
    caSpec' style' = crossAttentionSpec style' gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout dropoutP eps
    ffnSpec' :: _
    ffnSpec' style' = transformerFeedForwardNetworkSpec style' gradient device dataType queryEmbedDim ffnDim hasDropout dropoutP eps
    -- Bottom with a descriptive message for styles that have no decoder block
    -- component spec defined here.
    unsupported :: String -> String -> a
    unsupported component styleName =
      error $ "decoderBlockSpec: " ++ component ++ " spec is not defined for the " ++ styleName ++ " style"

-- | 'HasInitialize' instance for 'GTransformerBlock'.
--
-- A block of component types @selfAttention@, @crossAttention@, and
-- @feedForwardNetwork@ initializes to a block of the corresponding primed
-- component types. The constraints chain the generator device through the
-- components in that order: the output device of one component's
-- initialization is the input device of the next.
--
-- No methods are given here, so the instance relies on the class's default
-- method implementations.
instance
  ( HasInitialize selfAttention generatorDevice selfAttention' saGeneratorDevice,
    HasInitialize crossAttention saGeneratorDevice crossAttention' caGeneratorDevice,
    HasInitialize feedForwardNetwork caGeneratorDevice feedForwardNetwork' generatorOutputDevice
  ) =>
  HasInitialize
    (GTransformerBlock selfAttention crossAttention feedForwardNetwork)
    generatorDevice
    (GTransformerBlock selfAttention' crossAttention' feedForwardNetwork')
    generatorOutputDevice

-- | 'HasStateDict' instance for 'GTransformerBlock'.
--
-- Requires a 'HasStateDict' instance for each of the three components.
-- No methods are given here, so the instance relies on the class's default
-- method implementations.
instance
  (HasStateDict selfAttention, HasStateDict crossAttention, HasStateDict feedForwardNetwork) =>
  HasStateDict (GTransformerBlock selfAttention crossAttention feedForwardNetwork)

-- | 'HasForward' instance for 'GTransformerBlock' in an encoder configuration.
--
-- @
--      ┌───────┐  ┌───────────────┐
--      │ query │  │ attentionBias │
--      └───┬───┘  └───────┬───────┘
--          │              │
--          ▼              │
--   tbSelfAttention◄──────┘
--          ▼
-- tbFeedForwardNetwork
--          │
--          ▼
--      ┌───────┐
--      │ query │
--      └───────┘
-- @
instance
  ( HasForward
      selfAttention
      (query, attentionBias)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      feedForwardNetwork
      tensor0
      generatorDevice0
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformerBlock selfAttention () feedForwardNetwork)
    (query, attentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
  -- The cross-attention slot is '()' in this configuration and is not used.
  -- The input pair goes to the self-attention layer, and that layer's output
  -- goes to the feed-forward network. The generator is threaded through both
  -- steps by 'IxStateT', which lets its device type change from
  -- @generatorDevice@ via @generatorDevice0@ to @generatorOutputDevice@.
  forward GTransformerBlock {..} (query, attentionBias) =
    runIxStateT $
      ireturn (query, attentionBias)
        >>>= IxStateT . forward tbSelfAttention
        >>>= IxStateT . forward tbFeedForwardNetwork

-- | 'HasForward' instance for 'GTransformerBlock' in a decoder configuration.
--
-- The input is the four-tuple @(query, key, attentionBias, crossAttentionBias)@:
--
-- 1. @(query, attentionBias)@ is passed to 'tbSelfAttention'.
-- 2. The self-attention result, together with @key@ and @crossAttentionBias@,
--    is passed to 'tbCrossAttention'.
-- 3. The cross-attention result is passed to 'tbFeedForwardNetwork'.
--
-- The random generator is threaded through the three layers in that order.
-- In the diagram, @decoderAttentionBias@ is the @attentionBias@ component
-- of the input tuple.
--
-- @
-- ┌──────────────────────┐  ┌───────┐  ┌─────┐  ┌────────────────────┐
-- │ decoderAttentionBias │  │ query │  │ key │  │ crossAttentionBias │
-- └──────────┬───────────┘  └───┬───┘  └──┬──┘  └─────────┬──────────┘
--            │                  │         │               │
--            │                  ▼         │               │
--            └───────────►tbSelfAttention │               │
--                               │         │               │
--                               ▼         ▼               │
--                            tbCrossAttention◄────────────┘
--                               │
--                               ▼
--                      tbFeedForwardNetwork
--                               │
--                               ▼
--                           ┌───────┐
--                           │ query │
--                           └───────┘
-- @
instance
  ( HasForward
      selfAttention
      (query, attentionBias)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      crossAttention
      (tensor0, key, crossAttentionBias)
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      feedForwardNetwork
      tensor1
      generatorDevice1
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformerBlock selfAttention crossAttention feedForwardNetwork)
    (query, key, attentionBias, crossAttentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
  -- The indexed state monad carries the generator from one layer to the
  -- next while allowing its type-level device to change at every step.
  forward GTransformerBlock {..} (query, key, attentionBias, crossAttentionBias) =
    runIxStateT $
      ireturn (query, attentionBias)
        >>>= IxStateT . forward tbSelfAttention
        -- Re-pack the self-attention output with the key and the
        -- cross-attention bias, since the cross-attention layer takes a triple.
        >>>= (\query' -> IxStateT . forward tbCrossAttention $ (query', key, crossAttentionBias))
        >>>= IxStateT . forward tbFeedForwardNetwork