{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Transformer.GSelfAttention where
import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Dropout (Dropout (..))
import Torch.GraduallyTyped.NN.Normalization (LayerNorm (..), LayerNormSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention (GMultiHeadAttentionF, multiHeadAttentionSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), HasDropout (..), SHasBias (..), SHasDropout (..))
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, SShape (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
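-- | Generic self-attention layer.
-- Needs to be specialized to a concrete transformer style, e.g. 'T5'.
-- The type parameters are, in order, the initial layer normalization
-- ('saInitialLayerNorm'), the multi-head attention layer
-- ('saMultiHeadAttention'), the dropout layer ('saDropout'), and the
-- final layer normalization ('saFinalLayerNorm').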
data
  GSelfAttention
    (initialLayerNorm :: Type)
    (mha :: Type)
    (dropout :: Type)
    (finalLayerNorm :: Type)
  where
  GSelfAttention ::
    forall initialLayerNorm mha dropout finalLayerNorm.
    { saInitialLayerNorm :: initialLayerNorm,
      saMultiHeadAttention :: mha,
      saDropout :: dropout,
      saFinalLayerNorm :: finalLayerNorm
    } ->
    GSelfAttention initialLayerNorm mha dropout finalLayerNorm
  deriving stock (Eq, Ord, Show, Generic)
type instance
  ModelSpec (GSelfAttention initialLayerNorm mha dropout finalLayerNorm) =
    GSelfAttention (ModelSpec initialLayerNorm) (ModelSpec mha) (ModelSpec dropout) (ModelSpec finalLayerNorm)
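-- | Specifies the self-attention layer type for a given transformer style,
-- assembled from the component type families below.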
type family
  GSelfAttentionF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  GSelfAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout =
    GSelfAttention
      (SAInitialLayerNormF style gradient device dataType queryEmbedDim)
      (SAMultiheadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout)
      (SADropoutF style hasDropout)
      (SAFinalLayerNormF style gradient device dataType queryEmbedDim)
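-- | Specifies the type of the initial layer normalization:
-- T5- and ByT5-style models use a bias-free 'LayerNorm' before attention,
-- Pegasus uses one with bias, and BART- and BERT-style models use none ('()').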
type family
  SAInitialLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  SAInitialLayerNormF 'T5 gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[queryEmbedDim]))
  SAInitialLayerNormF 'ByT5 gradient device dataType queryEmbedDim =
    SAInitialLayerNormF 'T5 gradient device dataType queryEmbedDim
  SAInitialLayerNormF 'BART _ _ _ _ =
    ()
  SAInitialLayerNormF 'MBART gradient device dataType queryEmbedDim =
    SAInitialLayerNormF 'BART gradient device dataType queryEmbedDim
  SAInitialLayerNormF 'Pegasus gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  SAInitialLayerNormF 'BERT _ _ _ _ =
    ()
  SAInitialLayerNormF 'RoBERTa gradient device dataType queryEmbedDim =
    SAInitialLayerNormF 'BERT gradient device dataType queryEmbedDim
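-- | Specifies the type of the multi-head attention layer.
-- The key and value embedding dimensions are both the query embedding dimension.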
type family
  SAMultiheadAttentionF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  SAMultiheadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout =
    NamedModel (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim queryEmbedDim queryEmbedDim hasDropout)
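-- | Specifies the type of the dropout layer: 'Dropout' when dropout is
-- enabled, '()' otherwise.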
type family
  SADropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  SADropoutF _ 'WithDropout = Dropout
  SADropoutF _ 'WithoutDropout = ()
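-- | Specifies the type of the final layer normalization:
-- BART-, MBART-, BERT-, and RoBERTa-style models normalize after the residual
-- connection, while T5-, ByT5-, and Pegasus-style models use none ('()').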
type family
  SAFinalLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  SAFinalLayerNormF 'T5 _ _ _ _ =
    ()
  SAFinalLayerNormF 'ByT5 gradient device dataType queryEmbedDim =
    SAFinalLayerNormF 'T5 gradient device dataType queryEmbedDim
  SAFinalLayerNormF 'BART gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  SAFinalLayerNormF 'MBART gradient device dataType queryEmbedDim =
    SAFinalLayerNormF 'BART gradient device dataType queryEmbedDim
  SAFinalLayerNormF 'Pegasus gradient device dataType queryEmbedDim =
    ()
  SAFinalLayerNormF 'BERT gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  SAFinalLayerNormF 'RoBERTa gradient device dataType queryEmbedDim =
    SAFinalLayerNormF 'BERT gradient device dataType queryEmbedDim
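-- | Specifies the parameters of a self-attention layer for the given
-- transformer style. The two 'Double' arguments are the dropout probability
-- and the layer-normalization epsilon, respectively.
--
-- A rough usage sketch with made-up dimensions; the singleton constructors
-- used for the arguments ('SGradient', 'SWithGradient', 'SDevice', 'SCPU',
-- 'SDataType', 'SFloat', 'SName', 'SSize', ':&:', 'SWithDropout') are assumed
-- to be the ones exported elsewhere in this library:
--
-- > saSpec =
-- >   selfAttentionSpec
-- >     ST5
-- >     (SGradient SWithGradient)
-- >     (SDevice SCPU)
-- >     (SDataType SFloat)
-- >     (SName @"*" :&: SSize @8)   -- headDim
-- >     (SName @"*" :&: SSize @64)  -- headEmbedDim
-- >     (SName @"*" :&: SSize @512) -- embedDim
-- >     (SName @"*" :&: SSize @512) -- queryEmbedDim
-- >     SWithDropout
-- >     0.1  -- dropout probability
-- >     1e-6 -- layer-norm epsilon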
selfAttentionSpec ::
  forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout.
  STransformerStyle style ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim headDim ->
  SDim headEmbedDim ->
  SDim embedDim ->
  SDim queryEmbedDim ->
  SHasDropout hasDropout ->
  Double ->
  Double ->
  ModelSpec (GSelfAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout)
selfAttentionSpec style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout dropoutP eps =
  let initialLayerNormSpec ::
        STransformerStyle style ->
        ModelSpec (SAInitialLayerNormF style gradient device dataType queryEmbedDim)
      initialLayerNormSpec ST5 = NamedModel "layer_norm." layerNormWithoutBiasSpec
      initialLayerNormSpec SByT5 = NamedModel "layer_norm." layerNormWithoutBiasSpec
      initialLayerNormSpec SBART = ()
      initialLayerNormSpec SMBART = ()
      initialLayerNormSpec SPegasus = NamedModel "self_attn_layer_norm." layerNormWithBiasSpec
      initialLayerNormSpec SBERT = ()
      initialLayerNormSpec SRoBERTa = ()
      initialLayerNormSpec SGPT2 = undefined
      mhaSpec ::
        STransformerStyle style ->
        ModelSpec (SAMultiheadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim hasDropout)
      mhaSpec ST5 = NamedModel "SelfAttention." $ mhaSpec' ST5
      mhaSpec SByT5 = NamedModel "SelfAttention." $ mhaSpec' SByT5
      mhaSpec SBART = NamedModel "self_attn." $ mhaSpec' SBART
      mhaSpec SMBART = NamedModel "self_attn." $ mhaSpec' SMBART
      mhaSpec SPegasus = NamedModel "self_attn." $ mhaSpec' SPegasus
      mhaSpec SBERT = NamedModel mempty $ mhaSpec' SBERT
      mhaSpec SRoBERTa = NamedModel mempty $ mhaSpec' SRoBERTa
      mhaSpec SGPT2 = undefined
      dropoutSpec ::
        STransformerStyle style ->
        SHasDropout hasDropout ->
        ModelSpec (SADropoutF style hasDropout)
      dropoutSpec _ SWithDropout = Dropout dropoutP
      dropoutSpec _ SWithoutDropout = ()
      finalLayerNormSpec ::
        STransformerStyle style ->
        ModelSpec (SAFinalLayerNormF style gradient device dataType queryEmbedDim)
      finalLayerNormSpec ST5 = ()
      finalLayerNormSpec SByT5 = ()
      finalLayerNormSpec SBART = NamedModel "self_attn_layer_norm." layerNormWithBiasSpec
      finalLayerNormSpec SMBART = NamedModel "self_attn_layer_norm." layerNormWithBiasSpec
      finalLayerNormSpec SPegasus = ()
      finalLayerNormSpec SBERT = NamedModel "output.LayerNorm." layerNormWithBiasSpec
      finalLayerNormSpec SRoBERTa = NamedModel "output.LayerNorm." layerNormWithBiasSpec
      finalLayerNormSpec SGPT2 = undefined
   in GSelfAttention (initialLayerNormSpec style) (mhaSpec style) (dropoutSpec style hasDropout) (finalLayerNormSpec style)
  where
    mhaSpec' ::
      STransformerStyle style ->
      ModelSpec (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim queryEmbedDim queryEmbedDim hasDropout)
    mhaSpec' style' = multiHeadAttentionSpec style' gradient device dataType headDim headEmbedDim embedDim queryEmbedDim queryEmbedDim queryEmbedDim hasDropout dropoutP
    layerNormWithoutBiasSpec :: LayerNormSpec 'WithoutBias gradient device dataType ('Shape '[queryEmbedDim])
    layerNormWithoutBiasSpec = LayerNormSpec SWithoutBias gradient device dataType (SShape $ queryEmbedDim :|: SNil) eps
    layerNormWithBiasSpec :: LayerNormSpec 'WithBias gradient device dataType ('Shape '[queryEmbedDim])
    layerNormWithBiasSpec = LayerNormSpec SWithBias gradient device dataType (SShape $ queryEmbedDim :|: SNil) eps
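-- | Random initialization is obtained compositionally: each component is
-- initialized in sequence, threading the generator device through the layer.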
instance
  ( HasInitialize initialLayerNorm generatorDevice initialLayerNorm' generatorDevice0,
    HasInitialize multiHeadAttention generatorDevice0 multiHeadAttention' generatorDevice1,
    HasInitialize dropout generatorDevice1 dropout' generatorDevice2,
    HasInitialize finalLayerNorm generatorDevice2 finalLayerNorm' generatorOutputDevice
  ) =>
  HasInitialize
    (GSelfAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm)
    generatorDevice
    (GSelfAttention initialLayerNorm' multiHeadAttention' dropout' finalLayerNorm')
    generatorOutputDevice
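-- | State-dict loading and saving are derived compositionally from the
-- components' 'HasStateDict' instances.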
instance
  ( HasStateDict initialLayerNorm,
    HasStateDict multiHeadAttention,
    HasStateDict dropout,
    HasStateDict finalLayerNorm
  ) =>
  HasStateDict (GSelfAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm)
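-- | The forward pass of the self-attention layer:
-- the query is passed through the (possibly trivial) initial layer
-- normalization, fed to the multi-head attention layer as query, key, and
-- value together with the attention bias, passed through dropout, added back
-- to the original query (residual connection), and finally passed through the
-- (possibly trivial) final layer normalization.
--
-- A hypothetical invocation sketch, assuming @sa@ is an initialized
-- 'GSelfAttention' value and @g@ is a matching random generator:
--
-- > (output, g') <- forward sa (query, attentionBias) g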
instance
  ( HasForward
      initialLayerNorm
      (Tensor queryGradient queryLayout queryDevice queryDataType queryShape)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      multiHeadAttention
      ( tensor0,
        tensor0,
        tensor0,
        Tensor attentionBiasGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape
      )
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      dropout
      tensor1
      generatorDevice1
      (Tensor gradient2 layout2 device2 dataType2 shape2)
      generatorDevice2,
    HasForward
      finalLayerNorm
      (Tensor (queryGradient <|> gradient2) (queryLayout <+> layout2) (queryDevice <+> device2) (queryDataType <+> dataType2) (BroadcastShapesF queryShape shape2))
      generatorDevice2
      output
      generatorOutputDevice,
    Catch (BroadcastShapesF queryShape shape2)
  ) =>
  HasForward
    (GSelfAttention initialLayerNorm multiHeadAttention dropout finalLayerNorm)
    ( Tensor queryGradient queryLayout queryDevice queryDataType queryShape,
      Tensor attentionBiasGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape
    )
    generatorDevice
    output
    generatorOutputDevice
  where
  forward GSelfAttention {..} (query, attentionBias) =
    runIxStateT $
      ireturn query
        >>>= IxStateT . forward saInitialLayerNorm
        >>>= (\query' -> IxStateT $ forward saMultiHeadAttention (query', query', query', attentionBias))
        >>>= IxStateT . forward saDropout
        >>>= ilift . (query `add`)
        >>>= IxStateT . forward saFinalLayerNorm