{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
module Torch.GraduallyTyped.NN.Transformer.GBlock where
import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType, SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Transformer.GCrossAttention (GCrossAttentionF, crossAttentionSpec)
import Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork (GTransformerFeedForwardNetworkF, transformerFeedForwardNetworkSpec)
import Torch.GraduallyTyped.NN.Transformer.GSelfAttention (GSelfAttentionF, selfAttentionSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Size (..))
-- | Generic transformer block.
--
-- A plain record of three components, each of which has its own type
-- parameter, so the same data type can hold either the models themselves or
-- their specifications.
--
-- - @selfAttention@ is the type of the value stored in 'tbSelfAttention'.
-- - @crossAttention@ is the type of the value stored in 'tbCrossAttention'.
-- - @feedForwardNetwork@ is the type of the value stored in 'tbFeedForwardNetwork'.
--
-- 'Eq', 'Ord', 'Show', and 'Generic' are derived with the @stock@ strategy,
-- so each instance requires the corresponding instance for all three
-- component types.
data
  GTransformerBlock
    (selfAttention :: Type)
    (crossAttention :: Type)
    (feedForwardNetwork :: Type)
  where
  GTransformerBlock ::
    forall selfAttention crossAttention feedForwardNetwork.
    { -- | self-attention component of the block
      tbSelfAttention :: selfAttention,
      -- | cross-attention component of the block
      tbCrossAttention :: crossAttention,
      -- | feed-forward network component of the block
      tbFeedForwardNetwork :: feedForwardNetwork
    } ->
    GTransformerBlock selfAttention crossAttention feedForwardNetwork
  deriving stock (Eq, Ord, Show, Generic)
-- The specification of a transformer block is again a transformer block,
-- with every component replaced by that component's own 'ModelSpec'.
type instance
  ModelSpec (GTransformerBlock selfAttention crossAttention feedForwardNetwork) =
    GTransformerBlock
      (ModelSpec selfAttention)
      (ModelSpec crossAttention)
      (ModelSpec feedForwardNetwork)
-- | Type of an encoder block for the given configuration.
--
-- The single equation expands to a 'GTransformerBlock' whose
--
-- - self-attention slot is a 'NamedModel' around 'GSelfAttentionF',
-- - cross-attention slot is the unit type @()@,
-- - feed-forward slot is a 'NamedModel' around 'GTransformerFeedForwardNetworkF'.
--
-- Note that @ffnDim@ is only forwarded to the feed-forward network, while
-- @headDim@, @headEmbedDim@, and @embedDim@ are only forwarded to the
-- self-attention layer.
type family
  EncoderBlockF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout)
  where
  EncoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout =
    GTransformerBlock
      (NamedModel (GSelfAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout))
      ()
      (NamedModel (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout))
-- | Build the specification of an encoder block.
--
-- The result is a 'GTransformerBlock' of three specifications:
--
-- - a named self-attention specification produced by 'selfAttentionSpec',
-- - @()@ in the cross-attention slot,
-- - a named feed-forward specification produced by
--   'transformerFeedForwardNetworkSpec'.
--
-- The name prefix attached to each component depends on the style:
--
-- +---------------------------+------------------+--------------+
-- | style                     | self-attention   | feed-forward |
-- +===========================+==================+==============+
-- | 'ST5', 'SByT5'            | @"layer.0."@     | @"layer.1."@ |
-- +---------------------------+------------------+--------------+
-- | 'SBART', 'SMBART',        | 'mempty'         | 'mempty'     |
-- | 'SPegasus'                |                  |              |
-- +---------------------------+------------------+--------------+
-- | 'SBERT', 'SRoBERTa'       | @"attention."@   | 'mempty'     |
-- +---------------------------+------------------+--------------+
--
-- The two trailing 'Double' arguments are bound as @dropoutP@ and @eps@ and
-- are passed unchanged to both 'selfAttentionSpec' and
-- 'transformerFeedForwardNetworkSpec'.
--
-- For 'SGPT2' no component specification is defined: forcing the
-- self-attention or feed-forward field of the result raises an error.
encoderBlockSpec ::
  forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout.
  STransformerStyle style ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SDim ffnDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (EncoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout)
encoderBlockSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim ffnDim hasDropout dropoutP eps =
  let -- Self-attention specification, wrapped with the style's name prefix.
      saSpec ST5 = NamedModel "layer.0." $ saSpec' ST5
      saSpec SByT5 = NamedModel "layer.0." $ saSpec' SByT5
      saSpec SBART = NamedModel mempty $ saSpec' SBART
      saSpec SMBART = NamedModel mempty $ saSpec' SMBART
      saSpec SPegasus = NamedModel mempty $ saSpec' SPegasus
      saSpec SBERT = NamedModel "attention." $ saSpec' SBERT
      saSpec SRoBERTa = NamedModel "attention." $ saSpec' SRoBERTa
      saSpec SGPT2 = error "encoderBlockSpec: no self-attention specification is defined for the GPT2 style"
      -- The cross-attention slot is always unit, whatever the style.
      caSpec _ = ()
      -- Feed-forward specification, wrapped with the style's name prefix.
      ffnSpec ST5 = NamedModel "layer.1." $ ffnSpec' ST5
      ffnSpec SByT5 = NamedModel "layer.1." $ ffnSpec' SByT5
      ffnSpec SBART = NamedModel mempty $ ffnSpec' SBART
      ffnSpec SMBART = NamedModel mempty $ ffnSpec' SMBART
      ffnSpec SPegasus = NamedModel mempty $ ffnSpec' SPegasus
      ffnSpec SBERT = NamedModel mempty $ ffnSpec' SBERT
      ffnSpec SRoBERTa = NamedModel mempty $ ffnSpec' SRoBERTa
      ffnSpec SGPT2 = error "encoderBlockSpec: no feed-forward network specification is defined for the GPT2 style"
   in GTransformerBlock (saSpec style) (caSpec style) (ffnSpec style)
  where
    -- The helpers take the style singleton again so that each call site can
    -- pass the concrete constructor it has just matched on; their types are
    -- left to inference via partial type signatures.
    saSpec' :: _
    saSpec' style' = selfAttentionSpec style' gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout dropoutP eps
    ffnSpec' :: _
    ffnSpec' style' = transformerFeedForwardNetworkSpec style' gradient device dataType queryEmbedDim ffnDim hasDropout dropoutP eps
-- | Type of a transformer block in a decoder configuration.
--
-- Expands to a 'GTransformerBlock' whose three components are, in order:
--
-- * a 'NamedModel' around 'GSelfAttentionF',
-- * a 'NamedModel' around 'GCrossAttentionF' (the only component that
--   receives @keyEmbedDim@), and
-- * a 'NamedModel' around 'GTransformerFeedForwardNetworkF' (the only
--   component that receives @ffnDim@).
type family
  DecoderBlockF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (keyEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout)
  where
  DecoderBlockF
    style
    gradient
    device
    dataType
    headDim
    headEmbedDim
    embedDim
    queryEmbedDim
    keyEmbedDim
    ffnDim
    hasDropout =
    GTransformerBlock
      ( NamedModel
          (GSelfAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout)
      )
      ( NamedModel
          (GCrossAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout)
      )
      ( NamedModel
          (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout)
      )
-- | Specifies the parameters of a transformer block in a decoder configuration.
--
-- The result is a 'GTransformerBlock' holding the specifications of a
-- self-attention layer, a cross-attention layer, and a feed-forward network.
-- Each component is wrapped in a 'NamedModel' whose name prefix depends on the
-- transformer style:
--
-- * 'ST5' and 'SByT5' use @"layer.0."@ (self-attention), @"layer.1."@
--   (cross-attention), and @"layer.2."@ (feed-forward network);
-- * 'SBART', 'SMBART', and 'SPegasus' use the empty prefix ('mempty') for all
--   three components.
--
-- No decoder block is defined for 'SBERT', 'SRoBERTa', or 'SGPT2'. For these
-- styles all three components are bottom and raise a descriptive error once
-- they are forced.
decoderBlockSpec ::
  forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout.
  -- | transformer style; selects the name prefixes of the three components
  STransformerStyle style ->
  -- | gradient singleton, forwarded to all three sub-layer specs
  SGradient gradient ->
  -- | device singleton, forwarded to all three sub-layer specs
  SDevice device ->
  -- | data type singleton, forwarded to all three sub-layer specs
  SDataType dataType ->
  -- | head dimension, forwarded to the self- and cross-attention specs
  SDim headDim ->
  -- | head embedding dimension, forwarded to the self- and cross-attention specs
  SDim headEmbedDim ->
  -- | embedding dimension, forwarded to the self- and cross-attention specs
  SDim embedDim ->
  -- | query embedding dimension, forwarded to all three sub-layer specs
  SDim queryEmbedDim ->
  -- | key embedding dimension, forwarded to the cross-attention spec only
  SDim keyEmbedDim ->
  -- | feed-forward dimension, forwarded to the feed-forward network spec only
  SDim ffnDim ->
  -- | dropout switch, forwarded to all three sub-layer specs
  SHasDropout hasDropout ->
  -- | @dropoutP@, forwarded to all three sub-layer specs
  -- (presumably the dropout probability -- confirm against the callee)
  Double ->
  -- | @eps@, forwarded to all three sub-layer specs
  -- (presumably the layer-norm epsilon -- confirm against the callee)
  Double ->
  -- | specification of the decoder block
  ModelSpec (DecoderBlockF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout)
decoderBlockSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim ffnDim hasDropout dropoutP eps =
  -- Every equation re-applies the primed helper to the concrete singleton it
  -- has just matched on, rather than to a shared variable, so the helper is
  -- instantiated at a known style in each branch.
  let saSpec ST5 = NamedModel "layer.0." $ saSpec' ST5
      saSpec SByT5 = NamedModel "layer.0." $ saSpec' SByT5
      saSpec SBART = NamedModel mempty $ saSpec' SBART
      saSpec SMBART = NamedModel mempty $ saSpec' SMBART
      saSpec SPegasus = NamedModel mempty $ saSpec' SPegasus
      saSpec SBERT = unsupported "BERT"
      saSpec SRoBERTa = unsupported "RoBERTa"
      saSpec SGPT2 = unsupported "GPT2"
      caSpec ST5 = NamedModel "layer.1." $ caSpec' ST5
      caSpec SByT5 = NamedModel "layer.1." $ caSpec' SByT5
      caSpec SBART = NamedModel mempty $ caSpec' SBART
      caSpec SMBART = NamedModel mempty $ caSpec' SMBART
      caSpec SPegasus = NamedModel mempty $ caSpec' SPegasus
      caSpec SBERT = unsupported "BERT"
      caSpec SRoBERTa = unsupported "RoBERTa"
      caSpec SGPT2 = unsupported "GPT2"
      ffnSpec ST5 = NamedModel "layer.2." $ ffnSpec' ST5
      ffnSpec SByT5 = NamedModel "layer.2." $ ffnSpec' SByT5
      ffnSpec SBART = NamedModel mempty $ ffnSpec' SBART
      ffnSpec SMBART = NamedModel mempty $ ffnSpec' SMBART
      ffnSpec SPegasus = NamedModel mempty $ ffnSpec' SPegasus
      ffnSpec SBERT = unsupported "BERT"
      ffnSpec SRoBERTa = unsupported "RoBERTa"
      ffnSpec SGPT2 = unsupported "GPT2"
   in GTransformerBlock (saSpec style) (caSpec style) (ffnSpec style)
  where
    -- Self-attention spec for a given style; all other arguments are the
    -- ones passed to 'decoderBlockSpec'.
    saSpec' :: _
    saSpec' style' = selfAttentionSpec style' gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout dropoutP eps
    -- Cross-attention spec for a given style; additionally takes @keyEmbedDim@.
    caSpec' :: _
    caSpec' style' = crossAttentionSpec style' gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim hasDropout dropoutP eps
    -- Feed-forward network spec for a given style.
    ffnSpec' :: _
    ffnSpec' style' = transformerFeedForwardNetworkSpec style' gradient device dataType queryEmbedDim ffnDim hasDropout dropoutP eps
    -- Bottom with a descriptive message for styles without a decoder block.
    unsupported :: String -> a
    unsupported styleName =
      error ("decoderBlockSpec: no decoder block is defined for the " ++ styleName ++ " style")
-- | 'HasInitialize' instance for 'GTransformerBlock'.
--
-- Requires a 'HasInitialize' instance for each of the three components. The
-- generator device is chained through them in the order self-attention
-- (@generatorDevice@ to @generatorDevice0@), cross-attention
-- (@generatorDevice0@ to @generatorDevice1@), and feed-forward network
-- (@generatorDevice1@ to @generatorOutputDevice@).
--
-- No methods are defined here, so the class defaults are used
-- (presumably 'Generic'-based -- confirm against "Torch.GraduallyTyped.NN.Class").
instance
  ( HasInitialize selfAttention generatorDevice selfAttention' generatorDevice0,
    HasInitialize crossAttention generatorDevice0 crossAttention' generatorDevice1,
    HasInitialize feedForwardNetwork generatorDevice1 feedForwardNetwork' generatorOutputDevice
  ) =>
  HasInitialize (GTransformerBlock selfAttention crossAttention feedForwardNetwork) generatorDevice (GTransformerBlock selfAttention' crossAttention' feedForwardNetwork') generatorOutputDevice
-- | 'HasStateDict' instance for 'GTransformerBlock'.
--
-- Requires a 'HasStateDict' instance for each of the three components.
-- No methods are defined here, so the class defaults are used
-- (presumably 'Generic'-based -- confirm against "Torch.GraduallyTyped.NN.Class").
instance
  (HasStateDict selfAttention, HasStateDict crossAttention, HasStateDict feedForwardNetwork) =>
  HasStateDict (GTransformerBlock selfAttention crossAttention feedForwardNetwork)
-- | 'HasForward' instance for a 'GTransformerBlock' without cross-attention,
-- i.e. one whose cross-attention component is @()@.
--
-- The input is a pair @(query, attentionBias)@. It is passed to
-- 'tbSelfAttention'; the result (@tensor0@) is passed to
-- 'tbFeedForwardNetwork', whose result is the output of the block.
-- 'tbCrossAttention' is not used.
--
-- The generator is threaded through both layers in that order:
-- @generatorDevice@ to @generatorDevice0@ to @generatorOutputDevice@.
instance
  ( HasForward
      selfAttention
      (query, attentionBias)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      feedForwardNetwork
      tensor0
      generatorDevice0
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformerBlock selfAttention () feedForwardNetwork)
    (query, attentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward GTransformerBlock {..} (query, attentionBias) =
    -- The indexed state monad carries the generator, whose device index may
    -- change at every step.
    runIxStateT $
      ireturn (query, attentionBias)
        >>>= IxStateT . forward tbSelfAttention
        >>>= IxStateT . forward tbFeedForwardNetwork
-- | 'HasForward' instance for a 'GTransformerBlock' with cross-attention.
--
-- The input is a 4-tuple @(query, key, attentionBias, crossAttentionBias)@:
--
-- 1. @(query, attentionBias)@ is passed to 'tbSelfAttention', giving @tensor0@;
-- 2. @(tensor0, key, crossAttentionBias)@ is passed to 'tbCrossAttention',
--    giving @tensor1@;
-- 3. @tensor1@ is passed to 'tbFeedForwardNetwork', whose result is the
--    output of the block.
--
-- The generator is threaded through the three layers in that order:
-- @generatorDevice@ to @generatorDevice0@ to @generatorDevice1@ to
-- @generatorOutputDevice@.
instance
  ( HasForward
      selfAttention
      (query, attentionBias)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      crossAttention
      (tensor0, key, crossAttentionBias)
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      feedForwardNetwork
      tensor1
      generatorDevice1
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformerBlock selfAttention crossAttention feedForwardNetwork)
    (query, key, attentionBias, crossAttentionBias)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward GTransformerBlock {..} (query, key, attentionBias, crossAttentionBias) =
    -- The indexed state monad carries the generator, whose device index may
    -- change at every step.
    runIxStateT $
      ireturn (query, attentionBias)
        >>>= IxStateT . forward tbSelfAttention
        -- The self-attention output becomes the query of the cross-attention.
        >>>= (\query' -> IxStateT . forward tbCrossAttention $ (query', key, crossAttentionBias))
        >>>= IxStateT . forward tbFeedForwardNetwork