{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Transformer.GStack where

import Control.Monad.Indexed.State (IxStateT (..))
import Data.Functor.Indexed ((<<$>>))
import Data.Kind (Type)
import qualified Data.Vector as V hiding (uncons)
import qualified Data.Vector.Generic.Sized.Internal as VGS
import qualified Data.Vector.Sized as VS
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol, type (+))
import Torch.GraduallyTyped.DType (DType, DataType, SDataType (..))
import Torch.GraduallyTyped.Device (Device, DeviceType, SDevice (..))
import qualified Torch.GraduallyTyped.Internal.Vector as V
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, VectorSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GBlock (DecoderBlockF, EncoderBlockF, decoderBlockSpec, encoderBlockSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle, TransformerStyle)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim, Name, SDim, Size)

-- | Generic transformer stack.
--
-- - @stack@ is a stack of transformer blocks.
--
-- The wrapper adds no structure of its own; equality, ordering, and
-- rendering are all derived from the wrapped @stack@.
newtype GTransformerStack (stack :: Type) where
  GTransformerStack :: forall stack. stack -> GTransformerStack stack
  deriving stock (Eq, Ord, Show, Generic)

-- | The specification of a 'GTransformerStack' is a 'GTransformerStack'
-- wrapping the specification of the underlying @stack@.
type instance ModelSpec (GTransformerStack stack) = GTransformerStack (ModelSpec stack)

-- | Type of a transformer stack in an encoder configuration.
--
-- It is a 'GTransformerStack' wrapping a sized vector of @numLayers@ blocks,
-- where every block has the type computed by 'EncoderBlockF' from the
-- remaining parameters.
type family
  EncoderStackF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout)
  where
  EncoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout =
    GTransformerStack
      (VS.Vector numLayers (EncoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout))

-- | Specifies the parameters of a transformer stack in an encoder configuration.
--
-- - @style@: the style of the transformer stack, e.g. 'ST5', 'SByT5', etc.
-- - @numLayers@: the number of blocks in the stack.
-- - @gradient@: whether to compute the gradient of the stack's parameters.
-- - @device@: the computational device on which the stack is allocated.
-- - @dataType@: the data type of the stack's parameters.
-- - @headDim@: the dimension of all transformer heads in the stack.
-- - @headEmbedDim@: the dimension of the transformer head embeddings.
-- - @embedDim@: the dimension of the transformer embeddings.
-- - @queryEmbedDim@: the dimension of the transformer query embeddings.
-- - @ffnDim@: the dimension of the feed-forward network.
-- - @hasDropout@: the dropout singleton forwarded to every block specification.
-- - @dropoutP@: the dropout rate.
-- - @eps@: the epsilon value for numerical stability of the layer normalization.
--
-- The resulting specification contains @numLayers@ copies of one and the same
-- block specification built with 'encoderBlockSpec'.
encoderStackSpec ::
  forall style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout.
  STransformerStyle style ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SDim ffnDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (EncoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout)
-- Matching on the 'SNat' constructor brings the @KnownNat numLayers@ evidence
-- into scope, which 'VS.replicate'' needs.
encoderStackSpec style numLayers@SNat gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout dropoutP eps =
  let blockSpec = encoderBlockSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout dropoutP eps
   in GTransformerStack $ VectorSpec numLayers (VS.replicate' numLayers blockSpec)

-- | Type of a transformer stack in a decoder configuration.
--
-- It is a 'GTransformerStack' wrapping a sized vector of @numLayers@ blocks,
-- where every block has the type computed by 'DecoderBlockF' from the
-- remaining parameters.
type family
  DecoderStackF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (keyEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout)
  where
  DecoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout =
    GTransformerStack
      (VS.Vector numLayers (DecoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout))

-- | Specifies the parameters of a transformer stack in a decoder configuration.
--
-- - @style@: the style of the transformer stack, e.g. 'ST5', 'SByT5', etc.
-- - @numLayers@: the number of blocks in the stack.
-- - @gradient@: whether to compute the gradient of the stack's parameters.
-- - @device@: the computational device on which the stack is allocated.
-- - @dataType@: the data type of the stack's parameters.
-- - @headDim@: the dimension of all transformer heads in the stack.
-- - @headEmbedDim@: the dimension of the transformer head embeddings.
-- - @embedDim@: the dimension of the transformer embeddings.
-- - @queryEmbedDim@: the dimension of the transformer query embeddings.
-- - @keyEmbedDim@: the dimension of the transformer key embeddings.
-- - @ffnDim@: the dimension of the feed-forward network.
-- - @hasDropout@: the dropout singleton forwarded to every block specification.
-- - @dropoutP@: the dropout rate.
-- - @eps@: the epsilon value for numerical stability of the layer normalization.
--
-- The resulting specification contains @numLayers@ copies of one and the same
-- block specification built with 'decoderBlockSpec'.
decoderStackSpec ::
  forall style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout.
  STransformerStyle style ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SDim keyEmbedDim ->
  SDim ffnDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (DecoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout)
-- Matching on the 'SNat' constructor brings the @KnownNat numLayers@ evidence
-- into scope, which 'VS.replicate'' needs.
decoderStackSpec style numLayers@SNat gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout dropoutP eps =
  let blockSpec = decoderBlockSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout dropoutP eps
   in GTransformerStack $ VectorSpec numLayers (VS.replicate' numLayers blockSpec)

-- | A 'GTransformerStack' over a sized vector of blocks can be initialized
-- whenever a single block can. No methods are defined here, so the defaults
-- of the 'HasInitialize' class apply.
--
-- The equality constraint @numLayers' ~ (numLayers + 1)@ restricts this
-- instance to vector lengths of the form @numLayers + 1@.
instance
  (HasInitialize block generatorDevice block' generatorDevice, numLayers' ~ (numLayers + 1)) =>
  HasInitialize
    (GTransformerStack (VS.Vector numLayers' block))
    generatorDevice
    (GTransformerStack (VS.Vector numLayers' block'))
    generatorDevice

-- | A 'GTransformerStack' over a sized vector of blocks has a state dictionary
-- whenever a single block has one. No methods are defined here, so the
-- defaults of the 'HasStateDict' class apply.
instance HasStateDict block => HasStateDict (GTransformerStack (VS.Vector numLayers block))

-- | 'HasForward' instance for an empty 'GTransformerStack' applied to a
-- @(query, attentionBias)@ pair.
--
-- The query is returned unchanged, the attention bias is ignored, and the
-- generator is passed through as is.
instance
  HasForward
    (GTransformerStack (VS.Vector 0 block))
    (query, attentionBias)
    generatorDevice
    query
    generatorDevice
  where
  forward _ (query, _) = pure . (query,)

-- | 'HasForward' instance for a 'GTransformerStack' with exactly one block
-- applied to a @(query, attentionBias)@ pair.
--
-- The forward pass is delegated to the single block.
instance
  HasForward
    block
    (query, attentionBias)
    generatorDevice
    output
    generatorOutputDevice =>
  HasForward
    (GTransformerStack (VS.Vector 1 block))
    (query, attentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward (GTransformerStack (VGS.Vector v)) input g =
    -- The pattern is partial on the underlying unsized vector; the instance
    -- head fixes the length to 1, so a first element is expected to exist.
    let Just (block, _) = V.uncons v
     in forward block input g

-- | 'HasForward' instance for an empty 'GTransformerStack' applied to a
-- @(query, key, attentionBias, crossAttentionBias)@ tuple.
--
-- The query is returned unchanged, all other inputs are ignored, and the
-- generator is passed through as is.
instance
  HasForward
    (GTransformerStack (VS.Vector 0 block))
    (query, key, attentionBias, crossAttentionBias)
    generator
    query
    generator
  where
  forward _ (query, _, _, _) = pure . (query,)

-- | 'HasForward' instance for a 'GTransformerStack' with exactly one block
-- applied to a @(query, key, attentionBias, crossAttentionBias)@ tuple.
--
-- The forward pass is delegated to the single block.
instance
  HasForward
    block
    (query, key, attentionBias, crossAttentionBias)
    generatorDevice
    output
    generatorOutputDevice =>
  HasForward
    (GTransformerStack (VS.Vector 1 block))
    (query, key, attentionBias, crossAttentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward (GTransformerStack (VGS.Vector v)) input g =
    -- The pattern is partial on the underlying unsized vector; the instance
    -- head fixes the length to 1, so a first element is expected to exist.
    let Just (block, _) = V.uncons v
     in forward block input g

-- | 'HasForward' instance for 'GTransformerStack' in an encoder configuration.
--
-- @
-- ┌───────┐  ┌───────────────┐
-- │ query │  │ attentionBias │
-- └───┬───┘  └───────┬───────┘
--     │              │
--     ▼              │
--   block◄───────────┤
--     ▼              │
--   block◄───────────┤
--     ▼              │
--    ...            ...
--     ▼              │
--   block◄───────────┘
--     │
--     ▼
-- ┌───────┐
-- │ query │
-- └───────┘
-- @
--
-- The first block consumes the original @query@; every following block
-- consumes the output of its predecessor together with the same
-- @attentionBias@. The generator is threaded through all blocks in order.
instance
  {-# OVERLAPPABLE #-}
  ( HasForward
      block
      (query, attentionBias)
      generatorDevice
      output
      generatorOutputDevice,
    HasForward
      block
      (output, attentionBias)
      generatorOutputDevice
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformerStack (VS.Vector n block))
    (query, attentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward (GTransformerStack (VGS.Vector v)) (query, attentionBias) g =
    -- NOTE: the 'Just' pattern is partial and fails for an empty vector.
    -- This instance is OVERLAPPABLE, so the more specific @Vector 0@ instance
    -- is expected to be selected for empty stacks instead.
    let Just (block, blocks) = V.uncons v
     in V.foldl
          ( \agg block' -> do
              -- Run everything accumulated so far, then feed its output and
              -- generator into the next block.
              (output, g') <- agg
              forward block' (output, attentionBias) g'
          )
          (forward block (query, attentionBias) g)
          blocks

-- | 'HasForward' instance for 'GTransformerStack' in a decoder configuration.
--
-- @
-- ┌───────┐  ┌─────┐  ┌───────────────┐  ┌────────────────────┐
-- │ query │  │ key │  │ attentionBias │  │ crossAttentionBias │
-- └───┬───┘  └──┬──┘  └───────┬───────┘  └─────────┬──────────┘
--     │         │             │                    │
--     ▼         │             │                    │
--   block◄──────┤◄────────────┤◄───────────────────┤
--     ▼         │             │                    │
--   block◄──────┤◄────────────┤◄───────────────────┤
--     ▼         │             │                    │
--    ...       ...           ...                  ...
--     ▼         │             │                    │
--   block◄──────┘◄────────────┘◄───────────────────┘
--     │
--     ▼
-- ┌───────┐
-- │ query │
-- └───────┘
-- @
--
-- The first block consumes the original @query@; every following block
-- consumes the output of its predecessor together with the same @key@,
-- @attentionBias@, and @crossAttentionBias@. The generator is threaded
-- through all blocks in order.
instance
  {-# OVERLAPPABLE #-}
  ( HasForward
      block
      (query, key, attentionBias, crossAttentionBias)
      generatorDevice
      output
      generatorOutputDevice,
    HasForward
      block
      (output, key, attentionBias, crossAttentionBias)
      generatorOutputDevice
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformerStack (VS.Vector n block))
    (query, key, attentionBias, crossAttentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward (GTransformerStack (VGS.Vector v)) (query, key, attentionBias, crossAttentionBias) g =
    -- NOTE: the 'Just' pattern is partial and fails for an empty vector.
    -- This instance is OVERLAPPABLE, so the more specific @Vector 0@ instance
    -- is expected to be selected for empty stacks instead.
    let Just (block, blocks) = V.uncons v
     in V.foldl
          ( \agg block' -> do
              -- Run everything accumulated so far, then feed its output and
              -- generator into the next block.
              (output, g') <- agg
              forward block' (output, key, attentionBias, crossAttentionBias) g'
          )
          (forward block (query, key, attentionBias, crossAttentionBias) g)
          blocks