{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Transformer.GStack where
import Control.Monad.Indexed.State (IxStateT (..))
import Data.Functor.Indexed ((<<$>>))
import Data.Kind (Type)
import qualified Data.Vector as V hiding (uncons)
import qualified Data.Vector.Generic.Sized.Internal as VGS
import qualified Data.Vector.Sized as VS
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol, type (+))
import Torch.GraduallyTyped.DType (DType, DataType, SDataType (..))
import Torch.GraduallyTyped.Device (Device, DeviceType, SDevice (..))
import qualified Torch.GraduallyTyped.Internal.Vector as V
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, VectorSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GBlock (DecoderBlockF, EncoderBlockF, decoderBlockSpec, encoderBlockSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle, TransformerStyle)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim, Name, SDim, Size)
-- | Generic wrapper around a stack of transformer layers.
--
-- The wrapped value has the type given by the parameter @stack@. In this
-- module that is always a size-indexed vector of blocks (see 'EncoderStackF'
-- and 'DecoderStackF').
--
-- 'Generic' is derived alongside the stock classes. Presumably this lets
-- classes with generic default methods (e.g. the method-less instances
-- further down) work on the wrapper. NOTE(review): confirm against the class
-- definitions.
newtype GTransformerStack (stack :: Type) where
  GTransformerStack :: forall stack. stack -> GTransformerStack stack
  deriving stock (Eq, Ord, Show, Generic)
-- The spec of a stack is a stack of specs: 'ModelSpec' is pushed through the
-- 'GTransformerStack' wrapper onto whatever type it wraps.
type instance ModelSpec (GTransformerStack inner) = GTransformerStack (ModelSpec inner)
-- | Type of an encoder stack.
--
-- The result is a 'GTransformerStack' wrapping a size-indexed vector of
-- exactly @numLayers@ blocks. Every block has the type 'EncoderBlockF'
-- applied to the same parameters; all parameters except @numLayers@ are
-- passed through to 'EncoderBlockF'.
type family
  EncoderStackF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout)
  where
  EncoderStackF
    style
    numLayers
    gradient
    device
    dataType
    headDim
    headEmbedDim
    embedDim
    queryEmbedDim
    ffnDim
    hasDropout =
      GTransformerStack (VS.Vector numLayers (EncoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout))
-- | Build the 'ModelSpec' of an encoder stack made of @numLayers@ encoder
-- blocks.
--
-- One block spec is built with 'encoderBlockSpec' from every argument except
-- @numLayers@. That single spec is replicated @numLayers@ times into a
-- 'VectorSpec', so all layers share identical hyper-parameters.
--
-- The two trailing 'Double's are forwarded unchanged as the last two
-- arguments of 'encoderBlockSpec'. Going by their names they are presumably
-- the dropout probability and the normalisation epsilon; confirm against
-- 'encoderBlockSpec'.
encoderStackSpec ::
  forall style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout.
  STransformerStyle style ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SDim ffnDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (EncoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout)
-- Matching on the 'SNat' constructor presumably supplies the
-- @KnownNat numLayers@ evidence that 'VS.replicate'' requires.
encoderStackSpec style numLayers@SNat gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout dropoutP eps =
  let -- The signature refers to the outer type variables (ScopedTypeVariables).
      blockSpec ::
        ModelSpec (EncoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout)
      blockSpec =
        encoderBlockSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout dropoutP eps
   in GTransformerStack $ VectorSpec numLayers (VS.replicate' numLayers blockSpec)
-- | Type of a decoder stack.
--
-- The result is a 'GTransformerStack' wrapping a size-indexed vector of
-- exactly @numLayers@ blocks. Every block has the type 'DecoderBlockF'
-- applied to the same parameters; all parameters except @numLayers@ are
-- passed through to 'DecoderBlockF'.
type family
  DecoderStackF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (keyEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout)
  where
  DecoderStackF
    style
    numLayers
    gradient
    device
    dataType
    headDim
    headEmbedDim
    embedDim
    queryEmbedDim
    keyEmbedDim
    ffnDim
    hasDropout =
      GTransformerStack (VS.Vector numLayers (DecoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout))
-- | Build the 'ModelSpec' of a decoder stack made of @numLayers@ decoder
-- blocks.
--
-- One block spec is built with 'decoderBlockSpec' from every argument except
-- @numLayers@. That single spec is replicated @numLayers@ times into a
-- 'VectorSpec', so all layers share identical hyper-parameters.
--
-- The two trailing 'Double's are forwarded unchanged as the last two
-- arguments of 'decoderBlockSpec'. Going by their names they are presumably
-- the dropout probability and the normalisation epsilon; confirm against
-- 'decoderBlockSpec'.
decoderStackSpec ::
  forall style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout.
  STransformerStyle style ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SDim keyEmbedDim ->
  SDim ffnDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (DecoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout)
-- Matching on the 'SNat' constructor presumably supplies the
-- @KnownNat numLayers@ evidence that 'VS.replicate'' requires.
decoderStackSpec style numLayers@SNat gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout dropoutP eps =
  let -- The signature refers to the outer type variables (ScopedTypeVariables).
      blockSpec ::
        ModelSpec (DecoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout)
      blockSpec =
        decoderBlockSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout dropoutP eps
   in GTransformerStack $ VectorSpec numLayers (VS.replicate' numLayers blockSpec)
-- | Initialisation of a stack of @numLayers'@ blocks, given the per-block
-- instance @HasInitialize block generatorDevice block' generatorDevice@.
--
-- No methods are defined here, so the class's default method implementations
-- apply.
instance
  ( HasInitialize block generatorDevice block' generatorDevice,
    -- Only admits stack lengths that can be written as a successor
    -- @numLayers + 1@. NOTE(review): presumably this excludes empty stacks;
    -- confirm the intent.
    numLayers' ~ (numLayers + 1)
  ) =>
  HasInitialize
    (GTransformerStack (VS.Vector numLayers' block))
    generatorDevice
    (GTransformerStack (VS.Vector numLayers' block'))
    generatorDevice
-- | A stack of blocks supports state dicts whenever its block type does.
--
-- No methods are defined here, so the class's default method implementations
-- apply.
instance HasStateDict layer => HasStateDict (GTransformerStack (VS.Vector n layer))
-- | Forward pass through an empty encoder-style stack.
--
-- With no layers the stack is the identity: the query is returned unchanged,
-- the attention bias is ignored, and the generator is passed through as-is.
instance
  HasForward
    (GTransformerStack (VS.Vector 0 block))
    (query, attentionBias)
    generatorDevice
    query
    generatorDevice
  where
    forward _ (query, _) = pure . (query,)
-- | Forward pass through an encoder-style stack holding exactly one block.
--
-- The whole input and the generator are handed to that single block, and its
-- result is returned unchanged.
instance
  HasForward
    block
    (query, attentionBias)
    generatorDevice
    output
    generatorOutputDevice =>
  HasForward
    (GTransformerStack (VS.Vector 1 block))
    (query, attentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
    forward (GTransformerStack (VGS.Vector v)) input g =
      -- The size index guarantees one element. Match totally anyway so that a
      -- corrupted vector fails with a clear message, not a pattern-match error.
      case V.uncons v of
        Just (block, _) -> forward block input g
        Nothing ->
          error "GTransformerStack.forward: impossible, a stack of size 1 holds no block"
-- | Forward pass through an empty decoder-style stack.
--
-- With no layers the stack is the identity: the query is returned unchanged,
-- the key and both biases are ignored, and the generator is passed through
-- as-is.
instance
  HasForward
    (GTransformerStack (VS.Vector 0 block))
    (query, key, attentionBias, crossAttentionBias)
    generatorDevice
    query
    generatorDevice
  where
    forward _ (query, _, _, _) = pure . (query,)
-- | Forward pass through a decoder-style stack holding exactly one block.
--
-- The whole input (query, key, self- and cross-attention biases) and the
-- generator are handed to that single block, and its result is returned
-- unchanged.
instance
  HasForward
    block
    (query, key, attentionBias, crossAttentionBias)
    generatorDevice
    output
    generatorOutputDevice =>
  HasForward
    (GTransformerStack (VS.Vector 1 block))
    (query, key, attentionBias, crossAttentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
    forward (GTransformerStack (VGS.Vector v)) input g =
      -- The size index guarantees one element. Match totally anyway so that a
      -- corrupted vector fails with a clear message, not a pattern-match error.
      case V.uncons v of
        Just (block, _) -> forward block input g
        Nothing ->
          error "GTransformerStack.forward: impossible, a stack of size 1 holds no block"
-- | Forward pass through an encoder-style stack of @n@ blocks (fallback
-- instance; the size-0 and size-1 instances above take precedence).
--
-- The first block consumes @(query, attentionBias)@. Every following block
-- consumes @(previous output, attentionBias)@. The generator is threaded
-- through the blocks in order. The second constraint requires the block type
-- to map @output@ to @output@, so the blocks can be chained.
instance
  {-# OVERLAPPABLE #-}
  ( HasForward
      block
      (query, attentionBias)
      generatorDevice
      output
      generatorOutputDevice,
    HasForward
      block
      (output, attentionBias)
      generatorOutputDevice
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformerStack (VS.Vector n block))
    (query, attentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
    forward (GTransformerStack (VGS.Vector v)) (query, attentionBias) g =
      case V.uncons v of
        Nothing ->
          error "GTransformerStack.forward: cannot run an empty transformer stack"
        Just (block, blocks) -> do
          -- The first block is the only one that sees the original query.
          initial <- forward block (query, attentionBias) g
          -- The remaining blocks run left to right, each taking the previous
          -- block's output and the updated generator.
          V.foldM
            (\(output, g') block' -> forward block' (output, attentionBias) g')
            initial
            blocks
-- | Forward pass through a decoder-style stack of @n@ blocks (fallback
-- instance; the size-0 and size-1 instances above take precedence).
--
-- The first block consumes @(query, key, attentionBias, crossAttentionBias)@.
-- Every following block consumes the previous output in place of the query,
-- while key and both biases stay the same for every layer. The generator is
-- threaded through the blocks in order. The second constraint requires the
-- block type to map @output@ to @output@, so the blocks can be chained.
instance
  {-# OVERLAPPABLE #-}
  ( HasForward
      block
      (query, key, attentionBias, crossAttentionBias)
      generatorDevice
      output
      generatorOutputDevice,
    HasForward
      block
      (output, key, attentionBias, crossAttentionBias)
      generatorOutputDevice
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformerStack (VS.Vector n block))
    (query, key, attentionBias, crossAttentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
    forward (GTransformerStack (VGS.Vector v)) (query, key, attentionBias, crossAttentionBias) g =
      case V.uncons v of
        Nothing ->
          error "GTransformerStack.forward: cannot run an empty transformer stack"
        Just (block, blocks) -> do
          -- The first block is the only one that sees the original query.
          initial <- forward block (query, key, attentionBias, crossAttentionBias) g
          -- The remaining blocks run left to right, each taking the previous
          -- block's output and the updated generator.
          V.foldM
            ( \(output, g') block' ->
                forward block' (output, key, attentionBias, crossAttentionBias) g'
            )
            initial
            blocks