{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin TypeLevel.Rewrite
-fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyRightAssociativeL #-}
module Torch.GraduallyTyped.NN.Transformer.GMultiHeadAttention where
import Control.Monad.Catch (MonadThrow)
import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import Data.Singletons (SingKind (..))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Dropout (Dropout (..))
import Torch.GraduallyTyped.NN.Functional.NonLinearActivation (SoftmaxF, softmax)
import Torch.GraduallyTyped.NN.Linear (GLinear (..), GLinearF, linearSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), HasDropout (..), SHasBias (..), SHasDropout (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF, sGetDimFromShape, sUnifyDim, type (!))
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim (..), SSelectDim (..), SShape (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (ReshapeF, TransposeF, sReshape, sTranspose)
import Torch.GraduallyTyped.Tensor.MathOperations.BlasLapack (MatmulF, matmul)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add, mulScalar)
import Torch.GraduallyTyped.Tensor.Type (SGetShape (..), Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
-- | Scaling mode of a multi-head attention layer.
--
-- A value of this type is carried in the 'mhaScaling' field of
-- 'GMultiHeadAttention'.
--
-- The derived 'Ord' instance follows declaration order, i.e.
-- @MultiHeadAttentionWithoutScaling < MultiHeadAttentionWithQueryScaling <
-- MultiHeadAttentionWithWeightScaling@.
--
-- NOTE(review): the per-constructor descriptions below are inferred from the
-- constructor names only; confirm them against the forward pass.
data MultiHeadAttentionHasScaling
  = -- | No scaling is requested.
    MultiHeadAttentionWithoutScaling
  | -- | Presumably: scaling is applied to the projected query.
    MultiHeadAttentionWithQueryScaling
  | -- | Presumably: scaling is applied to the attention weights.
    MultiHeadAttentionWithWeightScaling
  deriving stock (Eq, Ord, Show, Generic)
-- | A 'MultiHeadAttentionHasScaling' value is its own model specification:
-- the spec type and the model type coincide.
type instance
  ModelSpec MultiHeadAttentionHasScaling =
    MultiHeadAttentionHasScaling
-- | Initialization of the scaling flag.
--
-- The input and output generator devices are the same type variable
-- (@generatorDevice@): the generator is passed through untouched.
instance HasInitialize MultiHeadAttentionHasScaling generatorDevice MultiHeadAttentionHasScaling generatorDevice where
  -- Method type at this instance:
  --
  -- > MonadThrow m
  -- >   => ModelSpec MultiHeadAttentionHasScaling
  -- >   -> Generator generatorDevice
  -- >   -> m (MultiHeadAttentionHasScaling, Generator generatorDevice)
  --
  -- The spec already is the model (see the 'ModelSpec' type instance), so it
  -- is returned as-is together with the unchanged generator @g@.
  initialize hasScaling g = pure (hasScaling, g)
-- | State-dictionary (de)serialization of the scaling flag.
--
-- Neither method reads from or writes to the state dictionary.
instance HasStateDict MultiHeadAttentionHasScaling where
  -- Method type at this instance:
  --
  -- > (MonadIO m, MonadThrow m, MonadState StateDict m)
  -- >   => ModelSpec MultiHeadAttentionHasScaling
  -- >   -> StateDictKey
  -- >   -> m MultiHeadAttentionHasScaling
  --
  -- The key is ignored; the spec itself is returned as the model.
  fromStateDict hasScaling _ = pure hasScaling

  -- Method type at this instance:
  --
  -- > (MonadThrow m, MonadState StateDict m)
  -- >   => StateDictKey -> MultiHeadAttentionHasScaling -> m ()
  --
  -- Both arguments are ignored; nothing is stored.
  toStateDict _ _ = pure ()
-- | Generic multi-head attention layer.
--
-- The three dimension parameters are tracked at the type level and are
-- available at run time through the singleton fields. The four projection
-- layers and the dropout layer are left abstract (any 'Type'), so the same
-- record serves both as a model and, via its 'ModelSpec' type instance, as a
-- model specification.
--
-- The derived 'Show' instance requires 'Show' for @qInProj@, @kInProj@,
-- @vInProj@, @outProj@ and @dropout@.
data
  GMultiHeadAttention
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (qInProj :: Type)
    (kInProj :: Type)
    (vInProj :: Type)
    (outProj :: Type)
    (dropout :: Type)
  where
  GMultiHeadAttention ::
    forall headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout.
    { -- | Singleton for the @headDim@ type parameter
      -- (presumably the dimension that enumerates the attention heads).
      mhaHeadDim :: SDim headDim,
      -- | Singleton for the @headEmbedDim@ type parameter
      -- (presumably the per-head embedding dimension).
      mhaHeadEmbedDim :: SDim headEmbedDim,
      -- | Singleton for the @embedDim@ type parameter.
      mhaEmbedDim :: SDim embedDim,
      -- | In-projection layer for the query.
      mhaQInProj :: qInProj,
      -- | In-projection layer for the key.
      mhaKInProj :: kInProj,
      -- | In-projection layer for the value.
      mhaVInProj :: vInProj,
      -- | Out-projection layer.
      mhaOutProj :: outProj,
      -- | Dropout layer.
      mhaDropout :: dropout,
      -- | Scaling mode, see 'MultiHeadAttentionHasScaling'.
      mhaScaling :: MultiHeadAttentionHasScaling
    } ->
    GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout
  deriving stock (Show, Generic)
-- | The specification of a 'GMultiHeadAttention' is the same record with
-- every sub-layer type replaced by that sub-layer's own 'ModelSpec'; the
-- three dimension parameters are carried over unchanged.
type instance
  ModelSpec (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout) =
    GMultiHeadAttention
      headDim
      headEmbedDim
      embedDim
      (ModelSpec qInProj)
      (ModelSpec kInProj)
      (ModelSpec vInProj)
      (ModelSpec outProj)
      (ModelSpec dropout)
-- | Computes the concrete 'GMultiHeadAttention' type for a given transformer
-- style.
--
-- The sub-layer types are delegated to 'QInProjF', 'KInProjF', 'VInProjF',
-- 'OutProjF' and 'DropoutF'. The query, key and value projections are
-- parameterised by their own embedding dimension and by @embedDim@; the
-- output projection is parameterised by @embedDim@ and @queryEmbedDim@.
type GMultiHeadAttentionF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  HasDropout ->
  Type
type family GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout where
  GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout =
    GMultiHeadAttention
      headDim
      headEmbedDim
      embedDim
      (QInProjF style gradient device dataType queryEmbedDim embedDim)
      (KInProjF style gradient device dataType keyEmbedDim embedDim)
      (VInProjF style gradient device dataType valueEmbedDim embedDim)
      (OutProjF style gradient device dataType embedDim queryEmbedDim)
      (DropoutF style hasDropout)
-- | Type of the query in-projection layer for a given transformer style.
--
-- @'T5@ uses a named linear layer without bias and @'ByT5@ defers to the
-- @'T5@ equation; every other style uses a named linear layer with bias.
type QInProjF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Type
type family QInProjF style gradient device dataType queryEmbedDim embedDim where
  QInProjF 'T5 gradient device dataType queryEmbedDim embedDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim embedDim)
  QInProjF 'ByT5 gradient device dataType queryEmbedDim embedDim =
    QInProjF 'T5 gradient device dataType queryEmbedDim embedDim
  QInProjF _ gradient device dataType queryEmbedDim embedDim =
    NamedModel (GLinearF 'WithBias gradient device dataType queryEmbedDim embedDim)
-- | Type of the key in-projection layer for a given transformer style.
--
-- @'T5@ uses a named linear layer without bias and @'ByT5@ defers to the
-- @'T5@ equation; every other style uses a named linear layer with bias.
type KInProjF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Type
type family KInProjF style gradient device dataType keyEmbedDim embedDim where
  KInProjF 'T5 gradient device dataType keyEmbedDim embedDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType keyEmbedDim embedDim)
  KInProjF 'ByT5 gradient device dataType keyEmbedDim embedDim =
    KInProjF 'T5 gradient device dataType keyEmbedDim embedDim
  KInProjF _ gradient device dataType keyEmbedDim embedDim =
    NamedModel (GLinearF 'WithBias gradient device dataType keyEmbedDim embedDim)
-- | Type of the value in-projection layer for a given transformer style.
--
-- @'T5@ uses a named linear layer without bias and @'ByT5@ defers to the
-- @'T5@ equation; every other style uses a named linear layer with bias.
type VInProjF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Type
type family VInProjF style gradient device dataType valueEmbedDim embedDim where
  VInProjF 'T5 gradient device dataType valueEmbedDim embedDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType valueEmbedDim embedDim)
  VInProjF 'ByT5 gradient device dataType valueEmbedDim embedDim =
    VInProjF 'T5 gradient device dataType valueEmbedDim embedDim
  VInProjF _ gradient device dataType valueEmbedDim embedDim =
    NamedModel (GLinearF 'WithBias gradient device dataType valueEmbedDim embedDim)
-- | Type of the out-projection layer for a given transformer style.
--
-- Note the argument order: @embedDim@ comes first and @queryEmbedDim@
-- second, the reverse of the in-projection families.
--
-- @'T5@ uses a named linear layer without bias and @'ByT5@ defers to the
-- @'T5@ equation; every other style uses a named linear layer with bias.
type OutProjF ::
  TransformerStyle ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  DataType DType ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Type
type family OutProjF style gradient device dataType embedDim queryEmbedDim where
  OutProjF 'T5 gradient device dataType embedDim queryEmbedDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType embedDim queryEmbedDim)
  OutProjF 'ByT5 gradient device dataType embedDim queryEmbedDim =
    OutProjF 'T5 gradient device dataType embedDim queryEmbedDim
  OutProjF _ gradient device dataType embedDim queryEmbedDim =
    NamedModel (GLinearF 'WithBias gradient device dataType embedDim queryEmbedDim)
-- | Type of the dropout layer.
--
-- The transformer style is ignored by both equations: @'WithDropout@ selects
-- 'Dropout', @'WithoutDropout@ selects the unit type.
type DropoutF :: TransformerStyle -> HasDropout -> Type
type family DropoutF style hasDropout where
  DropoutF _ 'WithDropout = Dropout
  DropoutF _ 'WithoutDropout = ()
multiHeadAttentionSpec ::
forall style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout.
STransformerStyle style ->
SGradient gradient ->
SDevice device ->
SDataType dataType ->
SDim headDim ->
SDim headEmbedDim ->
SDim embedDim ->
SDim queryEmbedDim ->
SDim keyEmbedDim ->
SDim valueEmbedDim ->
SHasDropout hasDropout ->
Double ->
ModelSpec (GMultiHeadAttentionF style gradient device dataType headDim headEmbedDim embedDim queryEmbedDim keyEmbedDim valueEmbedDim hasDropout)
multiHeadAttentionSpec :: forall (style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat))
(queryEmbedDim :: Dim (Name Symbol) (Size Nat))
(keyEmbedDim :: Dim (Name Symbol) (Size Nat))
(valueEmbedDim :: Dim (Name Symbol) (Size Nat))
(hasDropout :: HasDropout).
STransformerStyle style
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim queryEmbedDim
-> SDim keyEmbedDim
-> SDim valueEmbedDim
-> SHasDropout hasDropout
-> Double
-> ModelSpec
(GMultiHeadAttentionF
style
gradient
device
dataType
headDim
headEmbedDim
embedDim
queryEmbedDim
keyEmbedDim
valueEmbedDim
hasDropout)
multiHeadAttentionSpec STransformerStyle style
style SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim SDim keyEmbedDim
keyEmbedDim SDim valueEmbedDim
valueEmbedDim SHasDropout hasDropout
hasDropout Double
dropoutP =
let qInProjSpec :: STransformerStyle style
-> ModelSpec
(QInProjF style gradient device dataType queryEmbedDim embedDim)
qInProjSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel ()))
projSpecWithoutBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
qInProjSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel ()))
projSpecWithoutBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
qInProjSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
qInProjSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
qInProjSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"q_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
qInProjSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.query." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
qInProjSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.query." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim queryEmbedDim
queryEmbedDim SDim embedDim
embedDim)
qInProjSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
kInProjSpec :: STransformerStyle style
-> ModelSpec
(KInProjF style gradient device dataType keyEmbedDim embedDim)
kInProjSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel ()))
projSpecWithoutBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
kInProjSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel ()))
projSpecWithoutBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
kInProjSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
kInProjSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
kInProjSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"k_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
kInProjSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.key." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
kInProjSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.key." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim keyEmbedDim
keyEmbedDim SDim embedDim
embedDim)
kInProjSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
vInProjSpec :: STransformerStyle style
-> ModelSpec
(VInProjF style gradient device dataType valueEmbedDim embedDim)
vInProjSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel ()))
projSpecWithoutBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
vInProjSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel ()))
projSpecWithoutBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
vInProjSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
vInProjSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
vInProjSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"v_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
vInProjSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.value." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
vInProjSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"self.value." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim valueEmbedDim
valueEmbedDim SDim embedDim
embedDim)
vInProjSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
outProjSpec :: STransformerStyle style
-> ModelSpec
(OutProjF style gradient device dataType embedDim queryEmbedDim)
outProjSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"o." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel ()))
projSpecWithoutBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
outProjSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"o." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel ()))
projSpecWithoutBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
outProjSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"out_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
outProjSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"out_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
outProjSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"out_proj." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
outProjSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"output.dense." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
outProjSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"output.dense." (forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias SDim embedDim
embedDim SDim queryEmbedDim
queryEmbedDim)
outProjSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
dropoutSpec :: STransformerStyle style
-> SHasDropout hasDropout -> ModelSpec (DropoutF style hasDropout)
dropoutSpec STransformerStyle style
_ SHasDropout hasDropout
SWithDropout = Double -> Dropout
Dropout Double
dropoutP
dropoutSpec STransformerStyle style
_ SHasDropout hasDropout
SWithoutDropout = ()
scaling :: STransformerStyle style -> MultiHeadAttentionHasScaling
scaling :: STransformerStyle style -> MultiHeadAttentionHasScaling
scaling STransformerStyle style
ST5 = MultiHeadAttentionHasScaling
MultiHeadAttentionWithoutScaling
scaling STransformerStyle style
SByT5 = MultiHeadAttentionHasScaling
MultiHeadAttentionWithoutScaling
scaling STransformerStyle style
SBART = MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling
scaling STransformerStyle style
SMBART = MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling
scaling STransformerStyle style
SPegasus = MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling
scaling STransformerStyle style
SBERT = MultiHeadAttentionHasScaling
MultiHeadAttentionWithWeightScaling
scaling STransformerStyle style
SRoBERTa = MultiHeadAttentionHasScaling
MultiHeadAttentionWithWeightScaling
scaling STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
in forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> qInProj
-> kInProj
-> vInProj
-> outProj
-> dropout
-> MultiHeadAttentionHasScaling
-> GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
GMultiHeadAttention
SDim headDim
headDim
SDim headEmbedDim
headEmbedDim
SDim embedDim
embedDim
(STransformerStyle style
-> ModelSpec
(QInProjF style gradient device dataType queryEmbedDim embedDim)
qInProjSpec STransformerStyle style
style)
(STransformerStyle style
-> ModelSpec
(KInProjF style gradient device dataType keyEmbedDim embedDim)
kInProjSpec STransformerStyle style
style)
(STransformerStyle style
-> ModelSpec
(VInProjF style gradient device dataType valueEmbedDim embedDim)
vInProjSpec STransformerStyle style
style)
(STransformerStyle style
-> ModelSpec
(OutProjF style gradient device dataType embedDim queryEmbedDim)
outProjSpec STransformerStyle style
style)
(STransformerStyle style
-> SHasDropout hasDropout -> ModelSpec (DropoutF style hasDropout)
dropoutSpec STransformerStyle style
style SHasDropout hasDropout
hasDropout)
(STransformerStyle style -> MultiHeadAttentionHasScaling
scaling STransformerStyle style
style)
where
projSpecWithoutBias ::
forall inputDim outputDim.
SDim inputDim ->
SDim outputDim ->
ModelSpec
( GLinear
(NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])))
(NamedModel ())
)
projSpecWithoutBias :: forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel ()))
projSpecWithoutBias = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinearF hasBias gradient device dataType inputDim outputDim)
linearSpec SHasBias 'WithoutBias
SWithoutBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType
projSpecWithBias ::
forall inputDim outputDim.
SDim inputDim ->
SDim outputDim ->
ModelSpec
( GLinear
(NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])))
(NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim])))
)
projSpecWithBias :: forall (inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinear
(NamedModel
(Tensor
gradient
('Layout 'Dense)
device
dataType
('Shape '[outputDim, inputDim])))
(NamedModel
(Tensor
gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))))
projSpecWithBias = forall (hasBias :: HasBias) (gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(inputDim :: Dim (Name Symbol) (Size Nat))
(outputDim :: Dim (Name Symbol) (Size Nat)).
SHasBias hasBias
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputDim
-> SDim outputDim
-> ModelSpec
(GLinearF hasBias gradient device dataType inputDim outputDim)
linearSpec SHasBias 'WithBias
SWithBias SGradient gradient
gradient SDevice device
device SDataType dataType
dataType
-- | Initialization of a 'GMultiHeadAttention' from its specification.
--
-- The constraints chain the generator device through the five sub-models in
-- the order query projection, key projection, value projection, output
-- projection, dropout: each sub-model's output generator device
-- (@generatorDevice0@ .. @generatorDevice3@) is the next one's input, and the
-- last one yields @generatorOutputDevice@.
--
-- No methods are defined here, so the class defaults are used.
-- NOTE(review): presumably a 'Generic'-based default implementation; confirm
-- against the 'HasInitialize' class definition.
instance
( HasInitialize qInProj generatorDevice qInProj' generatorDevice0,
HasInitialize kInProj generatorDevice0 kInProj' generatorDevice1,
HasInitialize vInProj generatorDevice1 vInProj' generatorDevice2,
HasInitialize outProj generatorDevice2 outProj' generatorDevice3,
HasInitialize dropout generatorDevice3 dropout' generatorOutputDevice
) =>
HasInitialize
(GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout)
generatorDevice
(GMultiHeadAttention headDim headEmbedDim embedDim qInProj' kInProj' vInProj' outProj' dropout')
generatorOutputDevice
-- | State-dictionary support for 'GMultiHeadAttention'.
--
-- Requires 'HasStateDict' for every sub-model (query, key, and value input
-- projections, output projection, and dropout).
--
-- No methods are defined here, so the class defaults are used.
-- NOTE(review): presumably a 'Generic'-based default implementation; confirm
-- against the 'HasStateDict' class definition.
instance
( HasStateDict qInProj,
HasStateDict vInProj,
HasStateDict kInProj,
HasStateDict outProj,
HasStateDict dropout
) =>
HasStateDict (GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout)
-- | Type-level batch dimension of an attention input.
--
-- Takes the dimension at index 0 of the query, key, and value shapes and
-- combines the three with '<+>'. 'getBatchDim' is the term-level counterpart.
--
-- NOTE(review): the file enables the @TypeLevel.Rewrite@ plugin with a
-- @UnifyRightAssociativeL@ rule, so the exact association of this expression
-- may matter; do not re-parenthesize without checking.
type BatchDim ::
Shape [Dim (Name Symbol) (Size Nat)] ->
Shape [Dim (Name Symbol) (Size Nat)] ->
Shape [Dim (Name Symbol) (Size Nat)] ->
Dim (Name Symbol) (Size Nat)
type BatchDim queryShape keyShape valueShape =
(queryShape ! 0) <+> (keyShape ! 0) <+> (valueShape ! 0)
-- | Compute the batch dimension shared by the query, key, and value shapes.
--
-- The dimension at index 0 is looked up in each of the three shapes. The key
-- and value batch dimensions are unified first, and the query batch dimension
-- is then unified with that result, which gives the type-level 'BatchDim'.
--
-- Every step runs in @m@; a failing lookup or unification is reported through
-- the 'MonadThrow' constraint of 'sGetDimFromShape' and 'sUnifyDim'.
getBatchDim ::
  forall m queryShape keyShape valueShape batchDim.
  (MonadThrow m, batchDim ~ BatchDim queryShape keyShape valueShape) =>
  -- | singleton shape of the query tensor
  SShape queryShape ->
  -- | singleton shape of the key tensor
  SShape keyShape ->
  -- | singleton shape of the value tensor
  SShape valueShape ->
  -- | unified batch dimension
  m (SDim batchDim)
getBatchDim queryShape keyShape valueShape = do
  queryBatchDim <- sGetDimFromShape (SSelectDim $ SByIndex @0) queryShape
  keyBatchDim <- sGetDimFromShape (SSelectDim $ SByIndex @0) keyShape
  valueBatchDim <- sGetDimFromShape (SSelectDim $ SByIndex @0) valueShape
  -- Unify key and value first so the result associates to the right,
  -- matching the term order of the original implementation.
  keyValueBatchDim <- sUnifyDim keyBatchDim valueBatchDim
  sUnifyDim queryBatchDim keyValueBatchDim
-- | Type-level query sequence dimension: the dimension at index 1 of the
-- query shape. 'getQuerySeqDim' is the term-level counterpart.
type QuerySeqDim ::
Shape [Dim (Name Symbol) (Size Nat)] ->
Dim (Name Symbol) (Size Nat)
type QuerySeqDim queryShape =
queryShape ! 1
-- | Look up the query sequence dimension, i.e. the dimension at index 1 of
-- the query shape.
--
-- The lookup runs in @m@; a failure is reported through the 'MonadThrow'
-- constraint of 'sGetDimFromShape'.
getQuerySeqDim ::
  forall m queryShape querySeqDim.
  (MonadThrow m, querySeqDim ~ QuerySeqDim queryShape) =>
  -- | singleton shape of the query tensor
  SShape queryShape ->
  -- | dimension at index 1
  m (SDim querySeqDim)
getQuerySeqDim = sGetDimFromShape (SSelectDim $ SByIndex @1)
-- | Type-level key sequence dimension: the dimensions at index 1 of the key
-- and value shapes, combined with '<+>'. 'getKeySeqDim' is the term-level
-- counterpart.
type KeySeqDim ::
Shape [Dim (Name Symbol) (Size Nat)] ->
Shape [Dim (Name Symbol) (Size Nat)] ->
Dim (Name Symbol) (Size Nat)
type KeySeqDim keyShape valueShape =
(keyShape ! 1) <+> (valueShape ! 1)
-- | Compute the key sequence dimension shared by the key and value shapes.
--
-- The dimension at index 1 is looked up in both shapes and the two results
-- are unified, which gives the type-level 'KeySeqDim'.
--
-- Every step runs in @m@; a failing lookup or unification is reported through
-- the 'MonadThrow' constraint of 'sGetDimFromShape' and 'sUnifyDim'.
getKeySeqDim ::
  forall m keyShape valueShape keySeqDim.
  (MonadThrow m, keySeqDim ~ KeySeqDim keyShape valueShape) =>
  -- | singleton shape of the key tensor
  SShape keyShape ->
  -- | singleton shape of the value tensor
  SShape valueShape ->
  -- | unified sequence dimension of key and value
  m (SDim keySeqDim)
getKeySeqDim keyShape valueShape = do
  keySeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) keyShape
  valueSeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) valueShape
  sUnifyDim keySeqDim valueSeqDim
instance
( HasForward
qInProj
(Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape)
generatorDevice
(Tensor qRequiresGradient qLayout qDevice qDataType qShape0)
qGeneratorOutputDevice,
reshapedQShape0 ~ ReshapeF qShape0 ('Shape '[batchDim, querySeqDim, headDim, headEmbedDim]),
Catch reshapedQShape0,
qShape ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) reshapedQShape0,
Catch qShape,
HasForward
kInProj
(Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape)
qGeneratorOutputDevice
(Tensor qRequiresGradient kLayout kDevice kDataType kShape0)
kGeneratorOutputDevice,
reshapedKShape0 ~ ReshapeF kShape0 ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]),
Catch reshapedKShape0,
transposedReshapedKShape0 ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) reshapedKShape0,
Catch transposedReshapedKShape0,
doubleTransposedReshapedKShape0 ~ TransposeF ('SelectDim ('ByIndex 2)) ('SelectDim ('ByIndex 3)) transposedReshapedKShape0,
Catch doubleTransposedReshapedKShape0,
multipliedQDoubleTransposedReshapedKShape0 ~ MatmulF qShape doubleTransposedReshapedKShape0,
Catch multipliedQDoubleTransposedReshapedKShape0,
weightsShape0
~ SoftmaxF
('SelectDim ('ByIndex 3))
(BroadcastShapesF multipliedQDoubleTransposedReshapedKShape0 attentionBiasShape),
Catch (BroadcastShapesF multipliedQDoubleTransposedReshapedKShape0 attentionBiasShape),
Catch weightsShape0,
HasForward
dropout
( Tensor
(qRequiresGradient <|> attentionBiasRequiresGradient)
(qLayout <+> kLayout <+> attentionBiasLayout)
(qDevice <+> kDevice <+> attentionBiasDevice)
(qDataType <+> kDataType <+> attentionBiasDataType)
weightsShape0
)
kGeneratorOutputDevice
(Tensor weightsRequiresGradient weightsLayout weightsDevice weightsDataType weightsShape)
weightsGeneratorOutputDevice,
HasForward
vInProj
(Tensor valueRequiresGradient valueLayout valueDevice valueDataType valueShape)
weightsGeneratorOutputDevice
(Tensor weightsRequiresGradient vLayout vDevice vDataType vShape0)
vGeneratorOutputDevice,
reshapedVShape0 ~ ReshapeF vShape0 ('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]),
Catch reshapedVShape0,
transposedReshapedVShape ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) reshapedVShape0,
Catch transposedReshapedVShape,
multipliedWeightsTransposedReshapedVShape ~ MatmulF weightsShape transposedReshapedVShape,
Catch multipliedWeightsTransposedReshapedVShape,
outputQueryShape0 ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) multipliedWeightsTransposedReshapedVShape,
Catch outputQueryShape0,
HasForward
outProj
( Tensor
weightsRequiresGradient
(weightsLayout <+> vLayout)
(weightsDevice <+> vDevice)
(weightsDataType <+> vDataType)
reshapedOutputQueryShape0
)
vGeneratorOutputDevice
output
generatorOutputDevice,
reshapedOutputQueryShape0 ~ ReshapeF outputQueryShape0 ('Shape '[batchDim, querySeqDim, embedDim]),
Catch reshapedOutputQueryShape0,
SGetShape queryShape,
SGetShape keyShape,
SGetShape valueShape,
batchDim ~ BatchDim queryShape keyShape valueShape,
querySeqDim ~ QuerySeqDim queryShape,
keySeqDim ~ KeySeqDim keyShape valueShape
) =>
HasForward
(GMultiHeadAttention headDim headEmbedDim embedDim qInProj kInProj vInProj outProj dropout)
( Tensor queryRequiresGradient queryLayout queryDevice queryDataType queryShape,
Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape,
Tensor valueRequiresGradient valueLayout valueDevice valueDataType valueShape,
Tensor attentionBiasRequiresGradient attentionBiasLayout attentionBiasDevice attentionBiasDataType attentionBiasShape
)
generatorDevice
output
generatorOutputDevice
where
forward :: forall (m :: * -> *).
MonadThrow m =>
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> (Tensor
queryRequiresGradient
queryLayout
queryDevice
queryDataType
queryShape,
Tensor
keyRequiresGradient keyLayout keyDevice keyDataType keyShape,
Tensor
valueRequiresGradient
valueLayout
valueDevice
valueDataType
valueShape,
Tensor
attentionBiasRequiresGradient
attentionBiasLayout
attentionBiasDevice
attentionBiasDataType
attentionBiasShape)
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GMultiHeadAttention {qInProj
kInProj
dropout
vInProj
outProj
SDim headDim
SDim headEmbedDim
SDim embedDim
MultiHeadAttentionHasScaling
mhaScaling :: MultiHeadAttentionHasScaling
mhaDropout :: dropout
mhaOutProj :: outProj
mhaVInProj :: vInProj
mhaKInProj :: kInProj
mhaQInProj :: qInProj
mhaEmbedDim :: SDim embedDim
mhaHeadEmbedDim :: SDim headEmbedDim
mhaHeadDim :: SDim headDim
mhaScaling :: forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> MultiHeadAttentionHasScaling
mhaDropout :: forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> dropout
mhaOutProj :: forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> outProj
mhaVInProj :: forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> vInProj
mhaKInProj :: forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> kInProj
mhaQInProj :: forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> qInProj
mhaEmbedDim :: forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> SDim embedDim
mhaHeadEmbedDim :: forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> SDim headEmbedDim
mhaHeadDim :: forall (headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat)) qInProj kInProj vInProj
outProj dropout.
GMultiHeadAttention
headDim
headEmbedDim
embedDim
qInProj
kInProj
vInProj
outProj
dropout
-> SDim headDim
..} (Tensor
queryRequiresGradient
queryLayout
queryDevice
queryDataType
queryShape
query, Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape
key, Tensor
valueRequiresGradient
valueLayout
valueDevice
valueDataType
valueShape
value, Tensor
attentionBiasRequiresGradient
attentionBiasLayout
attentionBiasDevice
attentionBiasDataType
attentionBiasShape
attentionBias) Generator generatorDevice
g = do
SDim batchDim
batchDim <-
let queryShape :: SShape queryShape
queryShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor
queryRequiresGradient
queryLayout
queryDevice
queryDataType
queryShape
query
keyShape :: SShape keyShape
keyShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape
key
valueShape :: SShape valueShape
valueShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor
valueRequiresGradient
valueLayout
valueDevice
valueDataType
valueShape
value
in forall (m :: * -> *)
(queryShape :: Shape [Dim (Name Symbol) (Size Nat)])
(keyShape :: Shape [Dim (Name Symbol) (Size Nat)])
(valueShape :: Shape [Dim (Name Symbol) (Size Nat)])
(batchDim :: Dim (Name Symbol) (Size Nat)).
(MonadThrow m,
batchDim ~ BatchDim queryShape keyShape valueShape) =>
SShape queryShape
-> SShape keyShape -> SShape valueShape -> m (SDim batchDim)
getBatchDim SShape queryShape
queryShape SShape keyShape
keyShape SShape valueShape
valueShape
SDim querySeqDim
querySeqDim <-
let queryShape :: SShape queryShape
queryShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor
queryRequiresGradient
queryLayout
queryDevice
queryDataType
queryShape
query
in forall (m :: * -> *)
(queryShape :: Shape [Dim (Name Symbol) (Size Nat)])
(querySeqDim :: Dim (Name Symbol) (Size Nat)).
(MonadThrow m, querySeqDim ~ QuerySeqDim queryShape) =>
SShape queryShape -> m (SDim querySeqDim)
getQuerySeqDim SShape queryShape
queryShape
SDim keySeqDim
keySeqDim <-
let keyShape :: SShape keyShape
keyShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape
key
valueShape :: SShape valueShape
valueShape = forall (shape :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType).
SGetShape shape =>
Tensor gradient layout device dataType shape -> SShape shape
sGetShape Tensor
valueRequiresGradient
valueLayout
valueDevice
valueDataType
valueShape
value
in forall (m :: * -> *)
(keyShape :: Shape [Dim (Name Symbol) (Size Nat)])
(valueShape :: Shape [Dim (Name Symbol) (Size Nat)])
(keySeqDim :: Dim (Name Symbol) (Size Nat)).
(MonadThrow m, keySeqDim ~ KeySeqDim keyShape valueShape) =>
SShape keyShape -> SShape valueShape -> m (SDim keySeqDim)
getKeySeqDim SShape keyShape
keyShape SShape valueShape
valueShape
let scaling :: Double
scaling = (Double
1 :: Double) forall a. Fractional a => a -> a -> a
/ (forall a. Floating a => a -> a
sqrt forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (Integral a, Num b) => a -> b
fromIntegral forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. IsChecked a -> a
forgetIsChecked forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall name size. Dim name size -> size
dimSize forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall k (a :: k). SingKind k => Sing a -> Demote k
fromSing forall a b. (a -> b) -> a -> b
$ SDim headEmbedDim
mhaHeadEmbedDim)
forall a b c. (a -> b -> c) -> b -> a -> c
flip forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT Generator generatorDevice
g forall a b. (a -> b) -> a -> b
$
let q :: IxStateT
m
(Generator generatorDevice)
(Generator qGeneratorOutputDevice)
(Tensor
qRequiresGradient
qLayout
qDevice
qDataType
(TransposeF
('SelectDim ('ByIndex 1))
('SelectDim ('ByIndex 2))
(ReshapeImplF
(NumelF qShape0)
(LiftTimesMaybe
(NumelDimF batchDim)
(LiftTimesMaybe
(NumelDimF querySeqDim)
(LiftTimesMaybe
(NumelDimF headDim)
(LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
qShape0
('Shape '[batchDim, querySeqDim, headDim, headEmbedDim]))))
q =
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor
queryRequiresGradient
queryLayout
queryDevice
queryDataType
queryShape
query
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward qInProj
mhaQInProj
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ( \case
MultiHeadAttentionHasScaling
MultiHeadAttentionWithoutScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling -> forall a b c. (a -> b -> c) -> b -> a -> c
flip forall other (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(Scalar other, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> other -> m (Tensor gradient layout device dataType shape)
mulScalar Double
scaling
MultiHeadAttentionHasScaling
MultiHeadAttentionWithWeightScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
)
MultiHeadAttentionHasScaling
mhaScaling
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *)
(shape' :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ ReshapeF shape shape', Catch shape'') =>
SShape shape'
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape'')
sReshape (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim batchDim
batchDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim querySeqDim
querySeqDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headDim
mhaHeadDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headEmbedDim
mhaHeadEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
(selectDim1 :: SelectDim (By Symbol Nat))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @1)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2))
k :: IxStateT
m
(Generator qGeneratorOutputDevice)
(Generator kGeneratorOutputDevice)
(Tensor
qRequiresGradient
kLayout
kDevice
kDataType
(TransposeF
('SelectDim ('ByIndex 1))
('SelectDim ('ByIndex 2))
(ReshapeImplF
(NumelF kShape0)
(LiftTimesMaybe
(NumelDimF batchDim)
(LiftTimesMaybe
(NumelDimF keySeqDim)
(LiftTimesMaybe
(NumelDimF headDim)
(LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
kShape0
('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]))))
k =
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor keyRequiresGradient keyLayout keyDevice keyDataType keyShape
key
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward kInProj
mhaKInProj
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *)
(shape' :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ ReshapeF shape shape', Catch shape'') =>
SShape shape'
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape'')
sReshape (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim batchDim
batchDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim keySeqDim
keySeqDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headDim
mhaHeadDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headEmbedDim
mhaHeadEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
(selectDim1 :: SelectDim (By Symbol Nat))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @1)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2))
kt :: IxStateT
m
(Generator qGeneratorOutputDevice)
(Generator kGeneratorOutputDevice)
(Tensor
qRequiresGradient
kLayout
kDevice
kDataType
(TransposeF
('SelectDim ('ByIndex 2))
('SelectDim ('ByIndex 3))
(TransposeF
('SelectDim ('ByIndex 1))
('SelectDim ('ByIndex 2))
(ReshapeImplF
(NumelF kShape0)
(LiftTimesMaybe
(NumelDimF batchDim)
(LiftTimesMaybe
(NumelDimF keySeqDim)
(LiftTimesMaybe
(NumelDimF headDim)
(LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
kShape0
('Shape '[batchDim, keySeqDim, headDim, headEmbedDim])))))
kt = IxStateT
m
(Generator qGeneratorOutputDevice)
(Generator kGeneratorOutputDevice)
(Tensor
qRequiresGradient
kLayout
kDevice
kDataType
(TransposeF
('SelectDim ('ByIndex 1))
('SelectDim ('ByIndex 2))
(ReshapeImplF
(NumelF kShape0)
(LiftTimesMaybe
(NumelDimF batchDim)
(LiftTimesMaybe
(NumelDimF keySeqDim)
(LiftTimesMaybe
(NumelDimF headDim)
(LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
kShape0
('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]))))
k forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
(selectDim1 :: SelectDim (By Symbol Nat))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @3))
weights :: IxStateT
m
(Generator generatorDevice)
(Generator weightsGeneratorOutputDevice)
(Tensor
weightsRequiresGradient
weightsLayout
weightsDevice
weightsDataType
weightsShape)
weights =
(,) forall {k1} {k2} (f :: k1 -> k2 -> * -> *) a b (j :: k1)
(k3 :: k2).
IxFunctor f =>
(a -> b) -> f j k3 a -> f j k3 b
<<$>> IxStateT
m
(Generator generatorDevice)
(Generator qGeneratorOutputDevice)
(Tensor
qRequiresGradient
qLayout
qDevice
qDataType
(TransposeF
('SelectDim ('ByIndex 1))
('SelectDim ('ByIndex 2))
(ReshapeImplF
(NumelF qShape0)
(LiftTimesMaybe
(NumelDimF batchDim)
(LiftTimesMaybe
(NumelDimF querySeqDim)
(LiftTimesMaybe
(NumelDimF headDim)
(LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
qShape0
('Shape '[batchDim, querySeqDim, headDim, headEmbedDim]))))
q forall {k1} (f :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a b
(k2 :: k1).
IxApplicative f =>
f i j (a -> b) -> f j k2 a -> f i k2 b
<<*>> IxStateT
m
(Generator qGeneratorOutputDevice)
(Generator kGeneratorOutputDevice)
(Tensor
qRequiresGradient
kLayout
kDevice
kDataType
(TransposeF
('SelectDim ('ByIndex 2))
('SelectDim ('ByIndex 3))
(TransposeF
('SelectDim ('ByIndex 1))
('SelectDim ('ByIndex 2))
(ReshapeImplF
(NumelF kShape0)
(LiftTimesMaybe
(NumelDimF batchDim)
(LiftTimesMaybe
(NumelDimF keySeqDim)
(LiftTimesMaybe
(NumelDimF headDim)
(LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
kShape0
('Shape '[batchDim, keySeqDim, headDim, headEmbedDim])))))
kt
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry forall (m :: * -> *) (gradient :: Gradient RequiresGradient)
(gradient' :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (layout' :: Layout LayoutType)
(device :: Device (DeviceType Nat))
(device' :: Device (DeviceType Nat)) (dataType :: DataType DType)
(dataType' :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)])
(shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ MatmulF shape shape', Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
(gradient <|> gradient')
(layout <+> layout')
(device <+> device')
(dataType <+> dataType')
shape'')
matmul
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ( \case
MultiHeadAttentionHasScaling
MultiHeadAttentionWithoutScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
MultiHeadAttentionHasScaling
MultiHeadAttentionWithQueryScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
MultiHeadAttentionHasScaling
MultiHeadAttentionWithWeightScaling -> forall a b c. (a -> b -> c) -> b -> a -> c
flip forall other (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(Scalar other, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> other -> m (Tensor gradient layout device dataType shape)
mulScalar Double
scaling
)
MultiHeadAttentionHasScaling
mhaScaling
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient' :: Gradient RequiresGradient)
(layout' :: Layout LayoutType) (device' :: Device (DeviceType Nat))
(dataType' :: DataType DType)
(shape' :: Shape [Dim (Name Symbol) (Size Nat)])
(shape'' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(MonadThrow m, shape'' ~ BroadcastShapesF shape shape',
Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
(gradient <|> gradient')
(layout <+> layout')
(device <+> device')
(dataType <+> dataType')
shape'')
`add` Tensor
attentionBiasRequiresGradient
attentionBiasLayout
attentionBiasDevice
attentionBiasDataType
attentionBiasShape
attentionBias)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim :: SelectDim (By Symbol Nat))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(MonadThrow m, shape' ~ SoftmaxF selectDim shape, Catch shape') =>
SSelectDim selectDim
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
softmax (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @3))
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward dropout
mhaDropout
v :: IxStateT
m
(Generator weightsGeneratorOutputDevice)
(Generator vGeneratorOutputDevice)
(Tensor
weightsRequiresGradient
vLayout
vDevice
vDataType
(TransposeF
('SelectDim ('ByIndex 1))
('SelectDim ('ByIndex 2))
(ReshapeImplF
(NumelF vShape0)
(LiftTimesMaybe
(NumelDimF batchDim)
(LiftTimesMaybe
(NumelDimF keySeqDim)
(LiftTimesMaybe
(NumelDimF headDim)
(LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
vShape0
('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]))))
v =
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor
valueRequiresGradient
valueLayout
valueDevice
valueDataType
valueShape
value
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward vInProj
mhaVInProj
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *)
(shape' :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ ReshapeF shape shape', Catch shape'') =>
SShape shape'
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape'')
sReshape (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim batchDim
batchDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim keySeqDim
keySeqDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headDim
mhaHeadDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim headEmbedDim
mhaHeadEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
(selectDim1 :: SelectDim (By Symbol Nat))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @1)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2))
in (,) forall {k1} {k2} (f :: k1 -> k2 -> * -> *) a b (j :: k1)
(k3 :: k2).
IxFunctor f =>
(a -> b) -> f j k3 a -> f j k3 b
<<$>> IxStateT
m
(Generator generatorDevice)
(Generator weightsGeneratorOutputDevice)
(Tensor
weightsRequiresGradient
weightsLayout
weightsDevice
weightsDataType
weightsShape)
weights forall {k1} (f :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a b
(k2 :: k1).
IxApplicative f =>
f i j (a -> b) -> f j k2 a -> f i k2 b
<<*>> IxStateT
m
(Generator weightsGeneratorOutputDevice)
(Generator vGeneratorOutputDevice)
(Tensor
weightsRequiresGradient
vLayout
vDevice
vDataType
(TransposeF
('SelectDim ('ByIndex 1))
('SelectDim ('ByIndex 2))
(ReshapeImplF
(NumelF vShape0)
(LiftTimesMaybe
(NumelDimF batchDim)
(LiftTimesMaybe
(NumelDimF keySeqDim)
(LiftTimesMaybe
(NumelDimF headDim)
(LiftTimesMaybe (NumelDimF headEmbedDim) ('Just 1)))))
vShape0
('Shape '[batchDim, keySeqDim, headDim, headEmbedDim]))))
v
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry forall (m :: * -> *) (gradient :: Gradient RequiresGradient)
(gradient' :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (layout' :: Layout LayoutType)
(device :: Device (DeviceType Nat))
(device' :: Device (DeviceType Nat)) (dataType :: DataType DType)
(dataType' :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)])
(shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ MatmulF shape shape', Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
(gradient <|> gradient')
(layout <+> layout')
(device <+> device')
(dataType <+> dataType')
shape'')
matmul
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (selectDim0 :: SelectDim (By Symbol Nat))
(selectDim1 :: SelectDim (By Symbol Nat))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape',
MonadThrow m) =>
SSelectDim selectDim0
-> SSelectDim selectDim1
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sTranspose (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @1)) (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim (forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2))
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *)
(shape' :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape'' :: Shape [Dim (Name Symbol) (Size Nat)]).
(MonadThrow m, shape'' ~ ReshapeF shape shape', Catch shape'') =>
SShape shape'
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape'')
sReshape (forall (dims :: [Dim (Name Symbol) (Size Nat)]).
SList dims -> SShape ('Shape dims)
SShape forall a b. (a -> b) -> a -> b
$ SDim batchDim
batchDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim querySeqDim
querySeqDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: SDim embedDim
mhaEmbedDim forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward outProj
mhaOutProj