{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
module Torch.GraduallyTyped.NN.Transformer.GTransformer where
import Control.Monad.Indexed ((>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed (IxPointed (ireturn))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Dropout (Dropout (..))
import Torch.GraduallyTyped.NN.Normalization (LayerNorm (..), LayerNormSpec (..))
import Torch.GraduallyTyped.NN.Sparse (Embedding (..), EmbeddingSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GStack (DecoderStackF, EncoderStackF, decoderStackSpec, encoderStackSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), HasDropout (..), SHasBias (..), SHasDropout (..))
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.Prelude.Maybe (SMaybe (SNothing))
import Torch.GraduallyTyped.Prelude.TypeLits (SNat (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim, SSelectDim (..), SShape (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (TransposeF, UnsqueezeF, sTranspose, sUnsqueeze, transpose, unsqueeze)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
-- | Generic transformer container.
--
-- The type is a plain record of seven independently typed components. Each
-- type parameter is the type of the correspondingly named field, so a
-- component that is not used can be instantiated at @()@.
--
-- 'Eq', 'Ord', 'Show' and 'Generic' are stock-derived and therefore require
-- the respective instance for every component type.
data
  GTransformer
    (posEnc :: Type)
    (relPosEnc :: Type)
    (initialLayerNorm :: Type)
    (initialDropout :: Type)
    (stack :: Type)
    (finalLayerNorm :: Type)
    (finalDropout :: Type)
  where
  GTransformer ::
    forall posEnc relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout.
    { -- | positional encoding component
      tPosEnc :: posEnc,
      -- | relative positional encoding component
      tRelPosEnc :: relPosEnc,
      -- | initial layer normalization component
      tInitialLayerNorm :: initialLayerNorm,
      -- | initial dropout component
      tInitialDropout :: initialDropout,
      -- | stack component
      tStack :: stack,
      -- | final layer normalization component
      tFinalLayerNorm :: finalLayerNorm,
      -- | final dropout component
      tFinalDropout :: finalDropout
    } ->
    GTransformer posEnc relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout
  deriving stock (Eq, Ord, Show, Generic)
-- | The specification of a 'GTransformer' is again a 'GTransformer', with
-- every component type replaced by its own 'ModelSpec'.
type instance
  ModelSpec (GTransformer posEnc relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout) =
    GTransformer
      (ModelSpec posEnc)
      (ModelSpec relPosEnc)
      (ModelSpec initialLayerNorm)
      (ModelSpec initialDropout)
      (ModelSpec stack)
      (ModelSpec finalLayerNorm)
      (ModelSpec finalDropout)
-- | Computes the concrete encoder type for a given transformer style.
--
-- The result is a 'GTransformer' whose seven components are selected by the
-- per-component families 'TEPosEncF', 'TERelPosEncF', 'TEInitialLayerNormF',
-- 'TEInitialDropoutF', 'TEStackF', 'TEFinalLayerNormF' and 'TEFinalDropoutF'.
type family
  TransformerEncoderF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  TransformerEncoderF style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout =
    GTransformer
      (TEPosEncF style gradient device dataType inputEmbedDim posEncDim)
      (TERelPosEncF style gradient device dataType headDim posEncDim)
      (TEInitialLayerNormF style gradient device dataType inputEmbedDim)
      (TEInitialDropoutF style hasDropout)
      (TEStackF style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim hasDropout)
      (TEFinalLayerNormF style gradient device dataType inputEmbedDim)
      (TEFinalDropoutF style hasDropout)
-- | Selects the type of the positional encoding component of the encoder.
--
-- * @'T5@ and @'ByT5@ resolve to @()@.
-- * @'BART@, @'MBART@, @'Pegasus@, @'BERT@ and @'RoBERTa@ resolve to a
--   'NamedModel'-wrapped dense 'Embedding' parameterised by @posEncDim@ and
--   @inputEmbedDim@.
type family
  TEPosEncF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  -- T5 family: no component of this kind.
  TEPosEncF 'T5 _ _ _ _ _ = ()
  TEPosEncF 'ByT5 gradient device dataType inputEmbedDim posEncDim =
    TEPosEncF 'T5 gradient device dataType inputEmbedDim posEncDim
  -- BART family: named embedding; MBART and Pegasus defer to BART.
  TEPosEncF 'BART gradient device dataType inputEmbedDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim inputEmbedDim 'Nothing)
  TEPosEncF 'MBART gradient device dataType inputEmbedDim posEncDim =
    TEPosEncF 'BART gradient device dataType inputEmbedDim posEncDim
  TEPosEncF 'Pegasus gradient device dataType inputEmbedDim posEncDim =
    TEPosEncF 'BART gradient device dataType inputEmbedDim posEncDim
  -- BERT family: named embedding; RoBERTa defers to BERT.
  TEPosEncF 'BERT gradient device dataType inputEmbedDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim inputEmbedDim 'Nothing)
  TEPosEncF 'RoBERTa gradient device dataType inputEmbedDim posEncDim =
    TEPosEncF 'BERT gradient device dataType inputEmbedDim posEncDim
-- | Selects the type of the relative positional encoding component of the
-- encoder.
--
-- * @'T5@ and @'ByT5@ resolve to a 'NamedModel'-wrapped dense 'Embedding'
--   parameterised by @posEncDim@ and @headDim@.
-- * @'BART@, @'MBART@, @'Pegasus@, @'BERT@ and @'RoBERTa@ resolve to @()@.
type family
  TERelPosEncF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  -- T5 family: named embedding; ByT5 defers to T5.
  TERelPosEncF 'T5 gradient device dataType headDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim headDim 'Nothing)
  TERelPosEncF 'ByT5 gradient device dataType headDim posEncDim =
    TERelPosEncF 'T5 gradient device dataType headDim posEncDim
  -- BART family: no component of this kind.
  TERelPosEncF 'BART _ _ _ _ _ = ()
  TERelPosEncF 'MBART gradient device dataType headDim posEncDim =
    TERelPosEncF 'BART gradient device dataType headDim posEncDim
  TERelPosEncF 'Pegasus gradient device dataType headDim posEncDim =
    TERelPosEncF 'BART gradient device dataType headDim posEncDim
  -- BERT family: no component of this kind.
  TERelPosEncF 'BERT _ _ _ _ _ = ()
  TERelPosEncF 'RoBERTa gradient device dataType headDim posEncDim =
    TERelPosEncF 'BERT gradient device dataType headDim posEncDim
-- | Selects the type of the initial layer normalization component of the
-- encoder.
--
-- * @'T5@, @'ByT5@ and @'Pegasus@ resolve to @()@.
-- * @'BART@, @'MBART@, @'BERT@ and @'RoBERTa@ resolve to a
--   'NamedModel'-wrapped 'LayerNorm' with bias over the shape
--   @'[inputEmbedDim]@.
type family
  TEInitialLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TEInitialLayerNormF 'T5 _ _ _ _ = ()
  TEInitialLayerNormF 'ByT5 gradient device dataType inputEmbedDim =
    TEInitialLayerNormF 'T5 gradient device dataType inputEmbedDim
  TEInitialLayerNormF 'BART gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))
  TEInitialLayerNormF 'MBART gradient device dataType inputEmbedDim =
    TEInitialLayerNormF 'BART gradient device dataType inputEmbedDim
  -- Unlike BART and MBART, Pegasus has no initial layer norm.
  TEInitialLayerNormF 'Pegasus _ _ _ _ = ()
  TEInitialLayerNormF 'BERT gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))
  TEInitialLayerNormF 'RoBERTa gradient device dataType inputEmbedDim =
    TEInitialLayerNormF 'BERT gradient device dataType inputEmbedDim
-- | Selects the type of the initial dropout component of the encoder.
--
-- With @'WithDropout@, each of the seven styles listed below resolves to
-- 'Dropout'. With @'WithoutDropout@, any style resolves to @()@.
type family
  TEInitialDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  -- Styles are enumerated one by one rather than matched with a wildcard.
  TEInitialDropoutF 'T5 'WithDropout = Dropout
  TEInitialDropoutF 'ByT5 'WithDropout = Dropout
  TEInitialDropoutF 'BART 'WithDropout = Dropout
  TEInitialDropoutF 'MBART 'WithDropout = Dropout
  TEInitialDropoutF 'Pegasus 'WithDropout = Dropout
  TEInitialDropoutF 'BERT 'WithDropout = Dropout
  TEInitialDropoutF 'RoBERTa 'WithDropout = Dropout
  -- Dropout switched off: the component is unit for every style.
  TEInitialDropoutF _ 'WithoutDropout = ()
-- | Selects the type of the stack component of the encoder.
--
-- For every style this is the 'EncoderStackF' for the same parameters,
-- wrapped in 'NamedModel'.
type family
  TEStackF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  TEStackF style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim hasDropout =
    NamedModel
      (EncoderStackF style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim hasDropout)
-- | Selects the type of the final layer normalization component of the
-- encoder.
--
-- * @'T5@ and @'ByT5@ resolve to a 'NamedModel'-wrapped 'LayerNorm' without
--   bias over the shape @'[inputEmbedDim]@.
-- * @'Pegasus@ resolves to a 'NamedModel'-wrapped 'LayerNorm' with bias over
--   the same shape.
-- * @'BART@, @'MBART@, @'BERT@ and @'RoBERTa@ resolve to @()@.
type family
  TEFinalLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TEFinalLayerNormF 'T5 gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[inputEmbedDim]))
  TEFinalLayerNormF 'ByT5 gradient device dataType inputEmbedDim =
    TEFinalLayerNormF 'T5 gradient device dataType inputEmbedDim
  TEFinalLayerNormF 'BART _ _ _ _ = ()
  TEFinalLayerNormF 'MBART gradient device dataType inputEmbedDim =
    TEFinalLayerNormF 'BART gradient device dataType inputEmbedDim
  -- Unlike BART and MBART, Pegasus has a final layer norm (with bias).
  TEFinalLayerNormF 'Pegasus gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))
  TEFinalLayerNormF 'BERT _ _ _ _ = ()
  TEFinalLayerNormF 'RoBERTa gradient device dataType inputEmbedDim =
    TEFinalLayerNormF 'BERT gradient device dataType inputEmbedDim
-- | Selects the type of the final dropout component of the encoder.
--
-- * @'T5@ and @'ByT5@ resolve to 'Dropout' when @hasDropout@ is
--   @'WithDropout@.
-- * @'BART@, @'MBART@, @'Pegasus@, @'BERT@ and @'RoBERTa@ resolve to @()@
--   whatever @hasDropout@ is.
-- * Any style with @'WithoutDropout@ resolves to @()@.
type family
  TEFinalDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  TEFinalDropoutF 'T5 'WithDropout = Dropout
  TEFinalDropoutF 'ByT5 'WithDropout = Dropout
  -- These styles never have a final dropout, regardless of the flag.
  TEFinalDropoutF 'BART _ = ()
  TEFinalDropoutF 'MBART _ = ()
  TEFinalDropoutF 'Pegasus _ = ()
  TEFinalDropoutF 'BERT _ = ()
  TEFinalDropoutF 'RoBERTa _ = ()
  -- Dropout switched off: the component is unit for every remaining style.
  TEFinalDropoutF _ 'WithoutDropout = ()
-- | Builds the model specification of a transformer in an encoder configuration.
--
-- The result is a 'GTransformer' value with seven components
-- (positional encoding, relative positional encoding, initial layer norm,
-- initial dropout, stack, final layer norm, final dropout).
-- Each component is chosen by pattern matching on the @style@ singleton
-- (and, for the dropout components, on the @hasDropout@ singleton).
--
-- Every component selector maps 'SGPT2' to 'undefined', so forcing any
-- component of a spec built with 'SGPT2' fails at runtime.
transformerEncoderSpec ::
  forall style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout.
  -- | singleton selecting the transformer style, e.g. 'ST5', 'SBART', 'SBERT'
  STransformerStyle style ->
  -- | singleton for @numLayers@; forwarded to 'encoderStackSpec'
  SNat numLayers ->
  -- | singleton for @gradient@; used by the embedding, layer norm and stack specs
  SGradient gradient ->
  -- | singleton for @device@; used by the embedding, layer norm and stack specs
  SDevice device ->
  -- | singleton for @dataType@; used by the embedding, layer norm and stack specs
  SDataType dataType ->
  -- | singleton for @headDim@; forwarded to 'encoderStackSpec' and used as the
  -- embedding dimension of the relative positional encoding spec
  SDim headDim ->
  -- | singleton for @headEmbedDim@; forwarded to 'encoderStackSpec'
  SDim headEmbedDim ->
  -- | singleton for @embedDim@; forwarded to 'encoderStackSpec'
  SDim embedDim ->
  -- | singleton for @inputEmbedDim@; forwarded to 'encoderStackSpec', used as the
  -- embedding dimension of the positional encoding spec and as the normalized
  -- shape of the layer norm specs
  SDim inputEmbedDim ->
  -- | singleton for @ffnDim@; forwarded to 'encoderStackSpec'
  SDim ffnDim ->
  -- | singleton for @posEncDim@; used as the number of embeddings of both
  -- positional encoding specs
  SDim posEncDim ->
  -- | singleton deciding whether dropout components are 'Dropout' or @()@
  SHasDropout hasDropout ->
  -- | dropout value; wrapped in 'Dropout' and forwarded to 'encoderStackSpec'
  Double ->
  -- | epsilon; stored in 'LayerNormSpec' and forwarded to 'encoderStackSpec'
  Double ->
  -- | the resulting encoder specification
  ModelSpec (TransformerEncoderF style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout)
transformerEncoderSpec style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout dropoutP eps =
  let -- Absolute positional encoding: absent for T5/ByT5, a named embedding otherwise.
      posEncSpec ST5 = ()
      posEncSpec SByT5 = ()
      posEncSpec SBART = NamedModel "embed_positions." posEncSpec'
      posEncSpec SMBART = NamedModel "embed_positions." posEncSpec'
      posEncSpec SPegasus = NamedModel "embed_positions." posEncSpec'
      posEncSpec SBERT = NamedModel "embeddings.position_embeddings." posEncSpec'
      posEncSpec SRoBERTa = NamedModel "embeddings.position_embeddings." posEncSpec'
      posEncSpec SGPT2 = undefined
      -- Relative positional encoding: only present for T5/ByT5.
      relPosEncSpec ST5 = NamedModel "block.0.layer.0.SelfAttention.relative_attention_bias." relPosEncSpec'
      relPosEncSpec SByT5 = NamedModel "block.0.layer.0.SelfAttention.relative_attention_bias." relPosEncSpec'
      relPosEncSpec SBART = ()
      relPosEncSpec SMBART = ()
      relPosEncSpec SPegasus = ()
      relPosEncSpec SBERT = ()
      relPosEncSpec SRoBERTa = ()
      relPosEncSpec SGPT2 = undefined
      -- Layer norm applied before the stack: BART/MBART and BERT/RoBERTa only.
      initialLayerNormSpec ST5 = ()
      initialLayerNormSpec SByT5 = ()
      initialLayerNormSpec SBART = NamedModel "layernorm_embedding." $ layerNormSpec' SWithBias
      initialLayerNormSpec SMBART = NamedModel "layernorm_embedding." $ layerNormSpec' SWithBias
      initialLayerNormSpec SPegasus = ()
      initialLayerNormSpec SBERT = NamedModel "embeddings.LayerNorm." $ layerNormSpec' SWithBias
      initialLayerNormSpec SRoBERTa = NamedModel "embeddings.LayerNorm." $ layerNormSpec' SWithBias
      initialLayerNormSpec SGPT2 = undefined
      -- Dropout before the stack: present for every handled style when requested.
      initialDropoutSpec ST5 SWithDropout = Dropout dropoutP
      initialDropoutSpec ST5 SWithoutDropout = ()
      initialDropoutSpec SByT5 SWithDropout = Dropout dropoutP
      initialDropoutSpec SByT5 SWithoutDropout = ()
      initialDropoutSpec SBART SWithDropout = Dropout dropoutP
      initialDropoutSpec SBART SWithoutDropout = ()
      initialDropoutSpec SMBART SWithDropout = Dropout dropoutP
      initialDropoutSpec SMBART SWithoutDropout = ()
      initialDropoutSpec SPegasus SWithDropout = Dropout dropoutP
      initialDropoutSpec SPegasus SWithoutDropout = ()
      initialDropoutSpec SBERT SWithDropout = Dropout dropoutP
      initialDropoutSpec SBERT SWithoutDropout = ()
      initialDropoutSpec SRoBERTa SWithDropout = Dropout dropoutP
      initialDropoutSpec SRoBERTa SWithoutDropout = ()
      initialDropoutSpec SGPT2 _ = undefined
      -- Encoder stack: same construction for every style, only the name prefix differs.
      stackSpec ST5 = NamedModel "block." $ stackSpec' ST5
      stackSpec SByT5 = NamedModel "block." $ stackSpec' SByT5
      stackSpec SBART = NamedModel "layers." $ stackSpec' SBART
      stackSpec SMBART = NamedModel "layers." $ stackSpec' SMBART
      stackSpec SPegasus = NamedModel "layers." $ stackSpec' SPegasus
      stackSpec SBERT = NamedModel "encoder.layer." $ stackSpec' SBERT
      stackSpec SRoBERTa = NamedModel "encoder.layer." $ stackSpec' SRoBERTa
      stackSpec SGPT2 = undefined
      -- Layer norm applied after the stack: bias-free for T5/ByT5, with bias for Pegasus.
      finalLayerNormSpec ST5 = NamedModel "final_layer_norm." $ layerNormSpec' SWithoutBias
      finalLayerNormSpec SByT5 = NamedModel "final_layer_norm." $ layerNormSpec' SWithoutBias
      finalLayerNormSpec SBART = ()
      finalLayerNormSpec SMBART = ()
      finalLayerNormSpec SPegasus = NamedModel "layer_norm." $ layerNormSpec' SWithBias
      finalLayerNormSpec SBERT = ()
      finalLayerNormSpec SRoBERTa = ()
      finalLayerNormSpec SGPT2 = undefined
      -- Dropout after the stack: only T5/ByT5 have one, and only when requested.
      finalDropoutSpec ST5 SWithDropout = Dropout dropoutP
      finalDropoutSpec ST5 SWithoutDropout = ()
      finalDropoutSpec SByT5 SWithDropout = Dropout dropoutP
      finalDropoutSpec SByT5 SWithoutDropout = ()
      finalDropoutSpec SBART _ = ()
      finalDropoutSpec SMBART _ = ()
      finalDropoutSpec SPegasus _ = ()
      finalDropoutSpec SBERT _ = ()
      finalDropoutSpec SRoBERTa _ = ()
      finalDropoutSpec SGPT2 _ = undefined
   in GTransformer
        (posEncSpec style)
        (relPosEncSpec style)
        (initialLayerNormSpec style)
        (initialDropoutSpec style hasDropout)
        (stackSpec style)
        (finalLayerNormSpec style)
        (finalDropoutSpec style hasDropout)
  where
    -- Stack spec for a concrete style; the type is left to inference
    -- (partial type signature) because it is a large nested family application.
    stackSpec' :: _
    stackSpec' style' = encoderStackSpec style' numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim hasDropout dropoutP eps
    -- Layer norm spec over the one-dimensional shape @[inputEmbedDim]@.
    layerNormSpec' :: _
    layerNormSpec' hasBias = LayerNormSpec hasBias gradient device dataType (SShape $ inputEmbedDim :|: SNil) eps
    -- Dense embedding spec: posEncDim embeddings of dimension headDim, no padding index.
    relPosEncSpec' = EmbeddingSpec gradient (SLayout SDense) device dataType posEncDim headDim SNothing
    -- Dense embedding spec: posEncDim embeddings of dimension inputEmbedDim, no padding index.
    posEncSpec' = EmbeddingSpec gradient (SLayout SDense) device dataType posEncDim inputEmbedDim SNothing
-- | Type of a transformer in a decoder configuration.
--
-- Reduces to a 'GTransformer' whose seven component types are computed from
-- the indices by the @TD*F@ type families: positional encoding, relative
-- positional encoding, initial layer norm, initial dropout, stack,
-- final layer norm and final dropout, in that order.
type family
  TransformerDecoderF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (decoderInputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (encoderOutputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  TransformerDecoderF
    style
    numLayers
    gradient
    device
    dataType
    headDim
    headEmbedDim
    embedDim
    decoderInputEmbedDim
    encoderOutputEmbedDim
    ffnDim
    posEncDim
    hasDropout =
    GTransformer
      (TDPosEncF style gradient device dataType decoderInputEmbedDim posEncDim)
      (TDRelPosEncF style gradient device dataType headDim posEncDim)
      (TDInitialLayerNormF style gradient device dataType decoderInputEmbedDim)
      (TDInitialDropoutF style hasDropout)
      (TDStackF style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim hasDropout)
      (TDFinalLayerNormF style gradient device dataType decoderInputEmbedDim)
      (TDFinalDropoutF style hasDropout)
-- | Type of the positional encoding component of a decoder.
--
-- * @'T5@, @'ByT5@: no such component (@()@).
-- * @'BART@, @'MBART@, @'Pegasus@: a 'NamedModel' wrapping a dense
--   'Embedding' indexed by @posEncDim@ and @inputEmbedDim@ with padding
--   index @'Nothing@.
--
-- No equation is given for any other style.
type family
  TDPosEncF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TDPosEncF 'T5 _ _ _ _ _ =
    ()
  TDPosEncF 'ByT5 gradient device dataType inputEmbedDim posEncDim =
    TDPosEncF 'T5 gradient device dataType inputEmbedDim posEncDim
  TDPosEncF 'BART gradient device dataType inputEmbedDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim inputEmbedDim 'Nothing)
  TDPosEncF 'MBART gradient device dataType inputEmbedDim posEncDim =
    TDPosEncF 'BART gradient device dataType inputEmbedDim posEncDim
  TDPosEncF 'Pegasus gradient device dataType inputEmbedDim posEncDim =
    TDPosEncF 'BART gradient device dataType inputEmbedDim posEncDim
-- | Type of the relative positional encoding component of a decoder.
--
-- * @'T5@, @'ByT5@: a 'NamedModel' wrapping a dense 'Embedding' indexed by
--   @posEncDim@ and @headDim@ with padding index @'Nothing@.
-- * @'BART@, @'MBART@, @'Pegasus@: no such component (@()@).
--
-- No equation is given for any other style.
type family
  TDRelPosEncF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TDRelPosEncF 'T5 gradient device dataType headDim posEncDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType posEncDim headDim 'Nothing)
  TDRelPosEncF 'ByT5 gradient device dataType headDim posEncDim =
    TDRelPosEncF 'T5 gradient device dataType headDim posEncDim
  TDRelPosEncF 'BART _ _ _ _ _ =
    ()
  TDRelPosEncF 'MBART gradient device dataType headDim posEncDim =
    TDRelPosEncF 'BART gradient device dataType headDim posEncDim
  TDRelPosEncF 'Pegasus gradient device dataType headDim posEncDim =
    TDRelPosEncF 'BART gradient device dataType headDim posEncDim
-- | Type of the initial layer normalization component of a decoder.
--
-- * @'T5@, @'ByT5@, @'Pegasus@: no such component (@()@).
-- * @'BART@, @'MBART@: a 'NamedModel' wrapping a 'LayerNorm' with bias over
--   the shape @'[inputEmbedDim]@.
--
-- No equation is given for any other style.
type family
  TDInitialLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TDInitialLayerNormF 'T5 _ _ _ _ =
    ()
  TDInitialLayerNormF 'ByT5 gradient device dataType inputEmbedDim =
    TDInitialLayerNormF 'T5 gradient device dataType inputEmbedDim
  TDInitialLayerNormF 'BART gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))
  TDInitialLayerNormF 'MBART gradient device dataType inputEmbedDim =
    TDInitialLayerNormF 'BART gradient device dataType inputEmbedDim
  TDInitialLayerNormF 'Pegasus _ _ _ _ =
    ()
-- | Type of the initial dropout component of a decoder.
--
-- With @'WithDropout@ the component is 'Dropout' for each of @'T5@, @'ByT5@,
-- @'BART@, @'MBART@ and @'Pegasus@; with @'WithoutDropout@ it is @()@ for
-- every style. Equations are tried top to bottom.
type family
  TDInitialDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  TDInitialDropoutF 'T5 'WithDropout =
    Dropout
  TDInitialDropoutF 'ByT5 'WithDropout =
    Dropout
  TDInitialDropoutF 'BART 'WithDropout =
    Dropout
  TDInitialDropoutF 'MBART 'WithDropout =
    Dropout
  TDInitialDropoutF 'Pegasus 'WithDropout =
    Dropout
  TDInitialDropoutF _ 'WithoutDropout =
    ()
-- | Type of the stack component of a decoder.
--
-- For every style this is the 'DecoderStackF' type for the same indices,
-- wrapped in a 'NamedModel'.
type family
  TDStackF
    (style :: TransformerStyle)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (decoderInputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (encoderOutputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  TDStackF
    style
    numLayers
    gradient
    device
    dataType
    headDim
    headEmbedDim
    embedDim
    decoderInputEmbedDim
    encoderOutputEmbedDim
    ffnDim
    hasDropout =
    NamedModel
      ( DecoderStackF
          style
          numLayers
          gradient
          device
          dataType
          headDim
          headEmbedDim
          embedDim
          decoderInputEmbedDim
          encoderOutputEmbedDim
          ffnDim
          hasDropout
      )
-- | Type of the final layer normalization component of a decoder.
--
-- * @'T5@, @'ByT5@: a 'NamedModel' wrapping a 'LayerNorm' without bias over
--   the shape @'[inputEmbedDim]@.
-- * @'BART@, @'MBART@: no such component (@()@).
-- * @'Pegasus@: a 'NamedModel' wrapping a 'LayerNorm' with bias over the
--   shape @'[inputEmbedDim]@.
--
-- No equation is given for any other style.
type family
  TDFinalLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  TDFinalLayerNormF 'T5 gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[inputEmbedDim]))
  TDFinalLayerNormF 'ByT5 gradient device dataType inputEmbedDim =
    TDFinalLayerNormF 'T5 gradient device dataType inputEmbedDim
  TDFinalLayerNormF 'BART _ _ _ _ =
    ()
  TDFinalLayerNormF 'MBART gradient device dataType inputEmbedDim =
    TDFinalLayerNormF 'BART gradient device dataType inputEmbedDim
  TDFinalLayerNormF 'Pegasus gradient device dataType inputEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))
-- | Type of the final dropout slot of the transformer decoder.
--
-- Only the T5 and ByT5 styles carry a 'Dropout' here, and only when dropout
-- is enabled. BART, MBART and Pegasus always get @()@, as does any style
-- when dropout is disabled.
type family
  TDFinalDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  -- Styles that apply dropout after the final layer norm.
  TDFinalDropoutF 'T5 'WithDropout = Dropout
  TDFinalDropoutF 'ByT5 'WithDropout = Dropout
  -- Styles without a final dropout, whatever the dropout setting.
  TDFinalDropoutF 'BART _ = ()
  TDFinalDropoutF 'MBART _ = ()
  TDFinalDropoutF 'Pegasus _ = ()
  -- Dropout disabled: nothing to store.
  TDFinalDropoutF _ 'WithoutDropout = ()
-- | Build the model specification of a transformer decoder.
--
-- The result is a 'GTransformer' whose seven fields are chosen by @style@
-- (and, for the dropout fields, by @hasDropout@):
--
-- 1. absolute positional encoding,
-- 2. relative positional encoding,
-- 3. initial layer norm,
-- 4. initial dropout,
-- 5. the decoder stack,
-- 6. final layer norm,
-- 7. final dropout.
--
-- Only the T5, ByT5, BART, MBART and Pegasus styles are implemented. For
-- 'SBERT', 'SRoBERTa' and 'SGPT2' every field is 'undefined', so forcing any
-- part of the returned spec for those styles fails at runtime.
transformerDecoderSpec ::
  forall style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim posEncDim hasDropout.
  -- | style of the transformer; selects the layout of the decoder
  STransformerStyle style ->
  -- | number of layers, forwarded to 'decoderStackSpec'
  SNat numLayers ->
  -- | gradient singleton used for all parameters
  SGradient gradient ->
  -- | device singleton used for all parameters
  SDevice device ->
  -- | data type singleton used for all parameters
  SDataType dataType ->
  -- | head dimension; also the embedding dimension of the relative positional encoding
  SDim headDim ->
  -- | head embedding dimension, forwarded to 'decoderStackSpec'
  SDim headEmbedDim ->
  -- | embedding dimension, forwarded to 'decoderStackSpec'
  SDim embedDim ->
  -- | decoder input embedding dimension; also used for the layer norms and
  -- for the absolute positional encoding
  SDim decoderInputEmbedDim ->
  -- | encoder output embedding dimension, forwarded to 'decoderStackSpec'
  SDim encoderOutputEmbedDim ->
  -- | feed-forward network dimension, forwarded to 'decoderStackSpec'
  SDim ffnDim ->
  -- | number of embeddings of the positional encodings
  SDim posEncDim ->
  -- | whether or not dropout layers are part of the model
  SHasDropout hasDropout ->
  -- | dropout value handed to 'Dropout' and forwarded to 'decoderStackSpec'
  Double ->
  -- | epsilon handed to 'LayerNormSpec' and forwarded to 'decoderStackSpec'
  Double ->
  -- | specification of the transformer decoder
  ModelSpec (TransformerDecoderF style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim posEncDim hasDropout)
transformerDecoderSpec style numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim posEncDim hasDropout dropoutP eps =
  let -- Absolute positional encoding: none for T5/ByT5, an embedding named
      -- "embed_positions." for BART/MBART/Pegasus.
      posEncSpec ST5 = ()
      posEncSpec SByT5 = ()
      posEncSpec SBART = NamedModel "embed_positions." posEncSpec'
      posEncSpec SMBART = NamedModel "embed_positions." posEncSpec'
      posEncSpec SPegasus = NamedModel "embed_positions." posEncSpec'
      posEncSpec SBERT = undefined
      posEncSpec SRoBERTa = undefined
      posEncSpec SGPT2 = undefined
      -- Relative positional encoding: an embedding for T5/ByT5 (named after
      -- the relative attention bias of the first block), none for
      -- BART/MBART/Pegasus.
      relPosEncSpec ST5 = NamedModel "block.0.layer.0.SelfAttention.relative_attention_bias." relPosEncSpec'
      relPosEncSpec SByT5 = NamedModel "block.0.layer.0.SelfAttention.relative_attention_bias." relPosEncSpec'
      relPosEncSpec SBART = ()
      relPosEncSpec SMBART = ()
      relPosEncSpec SPegasus = ()
      relPosEncSpec SBERT = undefined
      relPosEncSpec SRoBERTa = undefined
      relPosEncSpec SGPT2 = undefined
      -- Initial layer norm: only BART and MBART have one (with bias).
      initialLayerNormSpec ST5 = ()
      initialLayerNormSpec SByT5 = ()
      initialLayerNormSpec SBART = NamedModel "layernorm_embedding." $ layerNormSpec' SWithBias
      initialLayerNormSpec SMBART = NamedModel "layernorm_embedding." $ layerNormSpec' SWithBias
      initialLayerNormSpec SPegasus = ()
      initialLayerNormSpec SBERT = undefined
      initialLayerNormSpec SRoBERTa = undefined
      initialLayerNormSpec SGPT2 = undefined
      -- Initial dropout: present for all implemented styles when dropout is
      -- enabled.
      initialDropoutSpec ST5 SWithDropout = Dropout dropoutP
      initialDropoutSpec ST5 SWithoutDropout = ()
      initialDropoutSpec SByT5 SWithDropout = Dropout dropoutP
      initialDropoutSpec SByT5 SWithoutDropout = ()
      initialDropoutSpec SBART SWithDropout = Dropout dropoutP
      initialDropoutSpec SBART SWithoutDropout = ()
      initialDropoutSpec SMBART SWithDropout = Dropout dropoutP
      initialDropoutSpec SMBART SWithoutDropout = ()
      initialDropoutSpec SPegasus SWithDropout = Dropout dropoutP
      initialDropoutSpec SPegasus SWithoutDropout = ()
      initialDropoutSpec SBERT _ = undefined
      initialDropoutSpec SRoBERTa _ = undefined
      initialDropoutSpec SGPT2 _ = undefined
      -- Decoder stack: named "block." for T5/ByT5 and "layers." for
      -- BART/MBART/Pegasus.
      stackSpec ST5 = NamedModel "block." $ stackSpec' ST5
      stackSpec SByT5 = NamedModel "block." $ stackSpec' SByT5
      stackSpec SBART = NamedModel "layers." $ stackSpec' SBART
      stackSpec SMBART = NamedModel "layers." $ stackSpec' SMBART
      stackSpec SPegasus = NamedModel "layers." $ stackSpec' SPegasus
      stackSpec SBERT = undefined
      stackSpec SRoBERTa = undefined
      stackSpec SGPT2 = undefined
      -- Final layer norm: without bias for T5/ByT5, with bias for Pegasus,
      -- none for BART/MBART.
      finalLayerNormSpec ST5 = NamedModel "final_layer_norm." $ layerNormSpec' SWithoutBias
      finalLayerNormSpec SByT5 = NamedModel "final_layer_norm." $ layerNormSpec' SWithoutBias
      finalLayerNormSpec SBART = ()
      finalLayerNormSpec SMBART = ()
      finalLayerNormSpec SPegasus = NamedModel "layer_norm." $ layerNormSpec' SWithBias
      finalLayerNormSpec SBERT = undefined
      finalLayerNormSpec SRoBERTa = undefined
      finalLayerNormSpec SGPT2 = undefined
      -- Final dropout: only T5/ByT5 have one, and only when dropout is
      -- enabled.
      finalDropoutSpec ST5 SWithDropout = Dropout dropoutP
      finalDropoutSpec ST5 SWithoutDropout = ()
      finalDropoutSpec SByT5 SWithDropout = Dropout dropoutP
      finalDropoutSpec SByT5 SWithoutDropout = ()
      finalDropoutSpec SBART _ = ()
      finalDropoutSpec SMBART _ = ()
      finalDropoutSpec SPegasus _ = ()
      finalDropoutSpec SBERT _ = undefined
      finalDropoutSpec SRoBERTa _ = undefined
      finalDropoutSpec SGPT2 _ = undefined
   in GTransformer
        (posEncSpec style)
        (relPosEncSpec style)
        (initialLayerNormSpec style)
        (initialDropoutSpec style hasDropout)
        (stackSpec style)
        (finalLayerNormSpec style)
        (finalDropoutSpec style hasDropout)
  where
    -- Decoder stack spec for a given style, sharing all other arguments.
    stackSpec' :: _
    stackSpec' style' = decoderStackSpec style' numLayers gradient device dataType headDim headEmbedDim embedDim decoderInputEmbedDim encoderOutputEmbedDim ffnDim hasDropout dropoutP eps
    -- Layer norm spec over the one-dimensional shape [decoderInputEmbedDim].
    layerNormSpec' :: _
    layerNormSpec' hasBias = LayerNormSpec hasBias gradient device dataType (SShape $ decoderInputEmbedDim :|: SNil) eps
    -- Dense embedding from posEncDim entries to headDim, no padding index.
    relPosEncSpec' = EmbeddingSpec gradient (SLayout SDense) device dataType posEncDim headDim SNothing
    -- Dense embedding from posEncDim entries to decoderInputEmbedDim, no
    -- padding index.
    posEncSpec' = EmbeddingSpec gradient (SLayout SDense) device dataType posEncDim decoderInputEmbedDim SNothing
-- | A 'GTransformer' spec can be initialized whenever each of its seven
-- fields can.
--
-- The constraints chain the generator device through the fields in
-- declaration order: the output device of one field's initialization is the
-- input device of the next, starting at @generatorDevice@ and ending at
-- @generatorOutputDevice@. The instance defines no methods itself and so
-- relies on the defaults of the 'HasInitialize' class.
instance
  ( HasInitialize posEnc generatorDevice posEnc' afterPosEnc,
    HasInitialize relPosEnc afterPosEnc relPosEnc' afterRelPosEnc,
    HasInitialize initialLayerNorm afterRelPosEnc initialLayerNorm' afterInitialLayerNorm,
    HasInitialize initialDropout afterInitialLayerNorm initialDropout' afterInitialDropout,
    HasInitialize stack afterInitialDropout stack' afterStack,
    HasInitialize finalLayerNorm afterStack finalLayerNorm' afterFinalLayerNorm,
    HasInitialize finalDropout afterFinalLayerNorm finalDropout' generatorOutputDevice
  ) =>
  HasInitialize
    (GTransformer posEnc relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
    generatorDevice
    (GTransformer posEnc' relPosEnc' initialLayerNorm' initialDropout' stack' finalLayerNorm' finalDropout')
    generatorOutputDevice
-- | A 'GTransformer' has a state dictionary whenever each of its seven
-- fields has one. The instance defines no methods itself and so relies on
-- the defaults of the 'HasStateDict' class.
instance
  ( HasStateDict posEnc,
    HasStateDict relPosEnc,
    HasStateDict initialLayerNorm,
    HasStateDict initialDropout,
    HasStateDict stack,
    HasStateDict finalLayerNorm,
    HasStateDict finalDropout
  ) =>
  HasStateDict
    ( GTransformer
        posEnc
        relPosEnc
        initialLayerNorm
        initialDropout
        stack
        finalLayerNorm
        finalDropout
    )
instance
( HasForward
posEnc
(Tensor posGradient posLayout posDevice posDataType posShape)
generatorDevice
(Tensor posEncGradient posEncLayout posEncDevice posEncDataType posEncShape)
generatorDevice0,
HasForward
initialLayerNorm
( Tensor
(inputGradient <|> posEncGradient)
(inputLayout <+> posEncLayout)
(inputDevice <+> posEncDevice)
(inputDataType <+> posEncDataType)
(BroadcastShapesF inputShape posEncShape)
)
generatorDevice0
tensor1
generatorDevice1,
Catch (BroadcastShapesF inputShape posEncShape),
HasForward
initialDropout
tensor1
generatorDevice1
tensor2
generatorDevice2,
HasForward
stack
( tensor2,
Tensor
attentionMaskGradient
attentionMaskLayout
attentionMaskDevice
attentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape)
)
generatorDevice2
tensor3
generatorDevice3,
Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape),
HasForward
finalLayerNorm
tensor3
generatorDevice3
tensor4
generatorDevice4,
HasForward
finalDropout
tensor4
generatorDevice4
output
generatorOutputDevice
) =>
HasForward
(GTransformer posEnc () initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
( Tensor inputGradient inputLayout inputDevice inputDataType inputShape,
Tensor posGradient posLayout posDevice posDataType posShape,
Tensor attentionMaskGradient attentionMaskLayout attentionMaskDevice attentionMaskDataType attentionMaskShape
)
generatorDevice
output
generatorOutputDevice
where
forward :: forall (m :: * -> *).
MonadThrow m =>
GTransformer
posEnc
()
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> (Tensor
inputGradient inputLayout inputDevice inputDataType inputShape,
Tensor posGradient posLayout posDevice posDataType posShape,
Tensor
attentionMaskGradient
attentionMaskLayout
attentionMaskDevice
attentionMaskDataType
attentionMaskShape)
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GTransformer {posEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
()
tFinalDropout :: finalDropout
tFinalLayerNorm :: finalLayerNorm
tStack :: stack
tInitialDropout :: initialDropout
tInitialLayerNorm :: initialLayerNorm
tRelPosEnc :: ()
tPosEnc :: posEnc
tFinalDropout :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> finalDropout
tFinalLayerNorm :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> finalLayerNorm
tStack :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> stack
tInitialDropout :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> initialDropout
tInitialLayerNorm :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> initialLayerNorm
tRelPosEnc :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> relPosEnc
tPosEnc :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> posEnc
..} (Tensor
inputGradient inputLayout inputDevice inputDataType inputShape
input, Tensor posGradient posLayout posDevice posDataType posShape
pos, Tensor
attentionMaskGradient
attentionMaskLayout
attentionMaskDevice
attentionMaskDataType
attentionMaskShape
attentionMask) =
let attentionBias :: IxStateT
m
(Generator generatorDevice2)
(Generator generatorDevice2)
(Tensor
attentionMaskGradient
attentionMaskLayout
attentionMaskDevice
attentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape))
attentionBias = forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall a b. (a -> b) -> a -> b
$ forall (selectDim :: SelectDim (By Symbol Natural))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType)
(device :: Device (DeviceType Natural))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Natural)])
(shape' :: Shape [Dim (Name Symbol) (Size Natural)]) (m :: * -> *).
(shape' ~ UnsqueezeF selectDim shape, Catch shape',
SingI selectDim, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
unsqueeze @('SelectDim ('ByIndex 1)) Tensor
attentionMaskGradient
attentionMaskLayout
attentionMaskDevice
attentionMaskDataType
attentionMaskShape
attentionMask
in forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor posGradient posLayout posDevice posDataType posShape
pos
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward posEnc
tPosEnc
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Tensor
inputGradient inputLayout inputDevice inputDataType inputShape
input forall (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType)
(device :: Device (DeviceType Natural))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Natural)])
(gradient' :: Gradient RequiresGradient)
(layout' :: Layout LayoutType)
(device' :: Device (DeviceType Natural))
(dataType' :: DataType DType)
(shape' :: Shape [Dim (Name Symbol) (Size Natural)])
(shape'' :: Shape [Dim (Name Symbol) (Size Natural)])
(m :: * -> *).
(MonadThrow m, shape'' ~ BroadcastShapesF shape shape',
Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
(gradient <|> gradient')
(layout <+> layout')
(device <+> device')
(dataType <+> dataType')
shape'')
`add`)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward initialLayerNorm
tInitialLayerNorm
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward initialDropout
tInitialDropout
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \tensor2
input' ->
IxStateT
m
(Generator generatorDevice2)
(Generator generatorDevice2)
(Tensor
attentionMaskGradient
attentionMaskLayout
attentionMaskDevice
attentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape))
attentionBias
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \Tensor
attentionMaskGradient
attentionMaskLayout
attentionMaskDevice
attentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape)
attentionBias' ->
forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall a b. (a -> b) -> a -> b
$ forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward stack
tStack (tensor2
input', Tensor
attentionMaskGradient
attentionMaskLayout
attentionMaskDevice
attentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape)
attentionBias')
)
)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward finalLayerNorm
tFinalLayerNorm
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward finalDropout
tFinalDropout
-- | Forward pass of a 'GTransformer' that has a relative positional
-- encoding (@relPosEnc@) and no absolute positional encoding (the @posEnc@
-- slot is @()@).
--
-- The input is the triple @(input, relPos, attentionMask)@. The steps are:
--
-- 1. @input@ is passed through 'tInitialLayerNorm' and then
--    'tInitialDropout'.
-- 2. @relPos@ is passed through 'tRelPosEnc'; the result is transposed
--    first along dimension indices 2 and 3, then along indices 1 and 2.
-- 3. @attentionMask@ is unsqueezed at dimension index 1 and added to the
--    tensor from step 2 (@add relPosBias unsqueezedMask@), giving the
--    attention bias.
-- 4. 'tStack' is applied to the pair of the tensor from step 1 and the
--    attention bias.
-- 5. 'tFinalLayerNorm' and then 'tFinalDropout' are applied.
--
-- The random generator is threaded through every component's 'forward'
-- call with 'IxStateT'. Per the instance constraints, 'tRelPosEnc' consumes
-- the generator after 'tInitialDropout' and before 'tStack'. The transposes,
-- the unsqueeze and 'add' run in the underlying monad via 'ilift'.
instance
  ( HasForward
      initialLayerNorm
      (Tensor inputGradient inputLayout inputDevice inputDataType inputShape)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      initialDropout
      tensor0
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      relPosEnc
      (Tensor relPosGradient relPosLayout relPosDevice relPosDataType relPosShape)
      generatorDevice1
      (Tensor relPosEncGradient relPosEncLayout relPosEncDevice relPosEncDataType relPosEncShape)
      generatorDevice2,
    HasForward
      stack
      ( tensor1,
        Tensor
          (relPosEncGradient <|> attentionMaskGradient)
          (relPosEncLayout <+> attentionMaskLayout)
          (relPosEncDevice <+> attentionMaskDevice)
          (relPosEncDataType <+> attentionMaskDataType)
          (BroadcastShapesF doubleTransposedRelPosEncShape unsqueezedAttentionMaskShape)
      )
      generatorDevice2
      tensor3
      generatorDevice3,
    transposedRelPosEncShape ~ TransposeF ('SelectDim ('ByIndex 2)) ('SelectDim ('ByIndex 3)) relPosEncShape,
    Catch transposedRelPosEncShape,
    doubleTransposedRelPosEncShape ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) transposedRelPosEncShape,
    Catch doubleTransposedRelPosEncShape,
    unsqueezedAttentionMaskShape ~ UnsqueezeF ('SelectDim ('ByIndex 1)) attentionMaskShape,
    Catch unsqueezedAttentionMaskShape,
    Catch (BroadcastShapesF doubleTransposedRelPosEncShape unsqueezedAttentionMaskShape),
    HasForward
      finalLayerNorm
      tensor3
      generatorDevice3
      tensor4
      generatorDevice4,
    HasForward
      finalDropout
      tensor4
      generatorDevice4
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformer () relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
    ( Tensor inputGradient inputLayout inputDevice inputDataType inputShape,
      Tensor relPosGradient relPosLayout relPosDevice relPosDataType relPosShape,
      Tensor attentionMaskGradient attentionMaskLayout attentionMaskDevice attentionMaskDataType attentionMaskShape
    )
    generatorDevice
    output
    generatorOutputDevice
  where
  forward GTransformer {..} (input, relPos, attentionMask) =
    let -- Encode the relative positions, then swap dimensions 2 <-> 3
        -- followed by 1 <-> 2 on the encoded tensor.
        relPosBias =
          ireturn relPos
            >>>= IxStateT . forward tRelPosEnc
            >>>= ilift . sTranspose (SSelectDim $ SByIndex @2) (SSelectDim $ SByIndex @3)
            >>>= ilift . sTranspose (SSelectDim $ SByIndex @1) (SSelectDim $ SByIndex @2)
        -- Unsqueeze the mask at dimension index 1 and add it to the
        -- relative position bias: the section feeds the unsqueezed mask
        -- as the second argument of 'add'.
        attentionBias =
          relPosBias
            >>>= ilift . (sUnsqueeze (SSelectDim $ SByIndex @1) attentionMask >>=) . add
     in runIxStateT $
          ireturn input
            >>>= IxStateT . forward tInitialLayerNorm
            >>>= IxStateT . forward tInitialDropout
            >>>= ( \input' ->
                     attentionBias
                       >>>= ( \attentionBias' ->
                                IxStateT $ forward tStack (input', attentionBias')
                            )
                 )
            >>>= IxStateT . forward tFinalLayerNorm
            >>>= IxStateT . forward tFinalDropout
instance
( HasForward
posEnc
(Tensor decoderPosGradient decoderPosLayout decoderPosDevice decoderPosDataType decoderPosShape)
generatorDevice
(Tensor decoderPosEncGradient decoderPosEncLayout decoderPosEncDevice decoderPosEncDataType decoderPosEncShape)
generatorDevice0,
HasForward
initialLayerNorm
( Tensor
(decoderInputGradient <|> decoderPosEncGradient)
(decoderInputLayout <+> decoderPosEncLayout)
(decoderInputDevice <+> decoderPosEncDevice)
(decoderInputDataType <+> decoderPosEncDataType)
(BroadcastShapesF decoderInputShape decoderPosEncShape)
)
generatorDevice0
tensor1
generatorDevice1,
HasForward
initialDropout
tensor1
generatorDevice1
tensor2
generatorDevice2,
HasForward
stack
( tensor2,
Tensor encoderOutputGradient encoderOutputLayout encoderOutputDevice encoderOutputDataType encoderOutputShape,
Tensor
decoderAttentionMaskGradient
decoderAttentionMaskLayout
decoderAttentionMaskDevice
decoderAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape),
Tensor
crossAttentionMaskGradient
crossAttentionMaskLayout
crossAttentionMaskDevice
crossAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape)
)
generatorDevice2
tensor3
generatorDevice3,
Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape),
Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape),
Catch (BroadcastShapesF decoderInputShape decoderPosEncShape),
HasForward
finalLayerNorm
tensor3
generatorDevice3
tensor4
generatorDevice4,
HasForward
finalDropout
tensor4
generatorDevice4
output
generatorOutputDevice
) =>
HasForward
(GTransformer posEnc () initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
( Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape,
Tensor encoderOutputGradient encoderOutputLayout encoderOutputDevice encoderOutputDataType encoderOutputShape,
Tensor decoderPosGradient decoderPosLayout decoderPosDevice decoderPosDataType decoderPosShape,
Tensor decoderAttentionMaskGradient decoderAttentionMaskLayout decoderAttentionMaskDevice decoderAttentionMaskDataType decoderAttentionMaskShape,
Tensor crossAttentionMaskGradient crossAttentionMaskLayout crossAttentionMaskDevice crossAttentionMaskDataType crossAttentionMaskShape
)
generatorDevice
output
generatorOutputDevice
where
forward :: forall (m :: * -> *).
MonadThrow m =>
GTransformer
posEnc
()
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> (Tensor
decoderInputGradient
decoderInputLayout
decoderInputDevice
decoderInputDataType
decoderInputShape,
Tensor
encoderOutputGradient
encoderOutputLayout
encoderOutputDevice
encoderOutputDataType
encoderOutputShape,
Tensor
decoderPosGradient
decoderPosLayout
decoderPosDevice
decoderPosDataType
decoderPosShape,
Tensor
decoderAttentionMaskGradient
decoderAttentionMaskLayout
decoderAttentionMaskDevice
decoderAttentionMaskDataType
decoderAttentionMaskShape,
Tensor
crossAttentionMaskGradient
crossAttentionMaskLayout
crossAttentionMaskDevice
crossAttentionMaskDataType
crossAttentionMaskShape)
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GTransformer {posEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
()
tFinalDropout :: finalDropout
tFinalLayerNorm :: finalLayerNorm
tStack :: stack
tInitialDropout :: initialDropout
tInitialLayerNorm :: initialLayerNorm
tRelPosEnc :: ()
tPosEnc :: posEnc
tFinalDropout :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> finalDropout
tFinalLayerNorm :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> finalLayerNorm
tStack :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> stack
tInitialDropout :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> initialDropout
tInitialLayerNorm :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> initialLayerNorm
tRelPosEnc :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> relPosEnc
tPosEnc :: forall posEnc relPosEnc initialLayerNorm initialDropout stack
finalLayerNorm finalDropout.
GTransformer
posEnc
relPosEnc
initialLayerNorm
initialDropout
stack
finalLayerNorm
finalDropout
-> posEnc
..} (Tensor
decoderInputGradient
decoderInputLayout
decoderInputDevice
decoderInputDataType
decoderInputShape
decoderInput, Tensor
encoderOutputGradient
encoderOutputLayout
encoderOutputDevice
encoderOutputDataType
encoderOutputShape
encoderOutput, Tensor
decoderPosGradient
decoderPosLayout
decoderPosDevice
decoderPosDataType
decoderPosShape
decoderPos, Tensor
decoderAttentionMaskGradient
decoderAttentionMaskLayout
decoderAttentionMaskDevice
decoderAttentionMaskDataType
decoderAttentionMaskShape
decoderAttentionMask, Tensor
crossAttentionMaskGradient
crossAttentionMaskLayout
crossAttentionMaskDevice
crossAttentionMaskDataType
crossAttentionMaskShape
crossAttentionMask) =
let decoderAttentionBias :: IxStateT
m
(Generator generatorDevice2)
(Generator generatorDevice2)
(Tensor
decoderAttentionMaskGradient
decoderAttentionMaskLayout
decoderAttentionMaskDevice
decoderAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape))
decoderAttentionBias = forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall a b. (a -> b) -> a -> b
$ forall (selectDim :: SelectDim (By Symbol Natural))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType)
(device :: Device (DeviceType Natural))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Natural)])
(shape' :: Shape [Dim (Name Symbol) (Size Natural)]) (m :: * -> *).
(shape' ~ UnsqueezeF selectDim shape, Catch shape',
SingI selectDim, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
unsqueeze @('SelectDim ('ByIndex 1)) Tensor
decoderAttentionMaskGradient
decoderAttentionMaskLayout
decoderAttentionMaskDevice
decoderAttentionMaskDataType
decoderAttentionMaskShape
decoderAttentionMask
crossAttentionBias :: IxStateT
m
(Generator generatorDevice2)
(Generator generatorDevice2)
(Tensor
crossAttentionMaskGradient
crossAttentionMaskLayout
crossAttentionMaskDevice
crossAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape))
crossAttentionBias = forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall a b. (a -> b) -> a -> b
$ forall (selectDim :: SelectDim (By Symbol Natural))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType)
(device :: Device (DeviceType Natural))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Natural)])
(shape' :: Shape [Dim (Name Symbol) (Size Natural)]) (m :: * -> *).
(shape' ~ UnsqueezeF selectDim shape, Catch shape',
SingI selectDim, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
unsqueeze @('SelectDim ('ByIndex 1)) Tensor
crossAttentionMaskGradient
crossAttentionMaskLayout
crossAttentionMaskDevice
crossAttentionMaskDataType
crossAttentionMaskShape
crossAttentionMask
in forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor
decoderPosGradient
decoderPosLayout
decoderPosDevice
decoderPosDataType
decoderPosShape
decoderPos
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward posEnc
tPosEnc
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Tensor
decoderInputGradient
decoderInputLayout
decoderInputDevice
decoderInputDataType
decoderInputShape
decoderInput forall (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType)
(device :: Device (DeviceType Natural))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Natural)])
(gradient' :: Gradient RequiresGradient)
(layout' :: Layout LayoutType)
(device' :: Device (DeviceType Natural))
(dataType' :: DataType DType)
(shape' :: Shape [Dim (Name Symbol) (Size Natural)])
(shape'' :: Shape [Dim (Name Symbol) (Size Natural)])
(m :: * -> *).
(MonadThrow m, shape'' ~ BroadcastShapesF shape shape',
Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
(gradient <|> gradient')
(layout <+> layout')
(device <+> device')
(dataType <+> dataType')
shape'')
`add`)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward initialLayerNorm
tInitialLayerNorm
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward initialDropout
tInitialDropout
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \tensor2
decoderInput' ->
IxStateT
m
(Generator generatorDevice2)
(Generator generatorDevice2)
(Tensor
decoderAttentionMaskGradient
decoderAttentionMaskLayout
decoderAttentionMaskDevice
decoderAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape))
decoderAttentionBias
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \Tensor
decoderAttentionMaskGradient
decoderAttentionMaskLayout
decoderAttentionMaskDevice
decoderAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape)
decoderAttentionBias' ->
IxStateT
m
(Generator generatorDevice2)
(Generator generatorDevice2)
(Tensor
crossAttentionMaskGradient
crossAttentionMaskLayout
crossAttentionMaskDevice
crossAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape))
crossAttentionBias
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \Tensor
crossAttentionMaskGradient
crossAttentionMaskLayout
crossAttentionMaskDevice
crossAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape)
crossAttentionBias' ->
forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall a b. (a -> b) -> a -> b
$
forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward
stack
tStack
( tensor2
decoderInput',
Tensor
encoderOutputGradient
encoderOutputLayout
encoderOutputDevice
encoderOutputDataType
encoderOutputShape
encoderOutput,
Tensor
decoderAttentionMaskGradient
decoderAttentionMaskLayout
decoderAttentionMaskDevice
decoderAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape)
decoderAttentionBias',
Tensor
crossAttentionMaskGradient
crossAttentionMaskLayout
crossAttentionMaskDevice
crossAttentionMaskDataType
(UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape)
crossAttentionBias'
)
)
)
)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward finalLayerNorm
tFinalLayerNorm
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Natural))
output (generatorOutputDevice :: Device (DeviceType Natural))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward finalDropout
tFinalDropout
-- | 'HasForward' instance for a 'GTransformer' that has no absolute
-- positional encoding (the @posEnc@ slot is @()@) and is applied to a
-- five-component input:
--
-- @(decoderInput, encoderOutput, decoderRelPos, decoderAttentionMask, crossAttentionMask)@
--
-- The forward pass is assembled as an indexed-state pipeline so that the
-- random generator can change its device index from step to step:
--
-- 1. @decoderInput@ is passed through 'tInitialLayerNorm' and then
--    'tInitialDropout' (generator: @generatorDevice@ -> @generatorDevice0@
--    -> @generatorDevice1@).
-- 2. @decoderRelPos@ is passed through 'tRelPosEnc' (generator:
--    @generatorDevice1@ -> @generatorDevice2@); the result is transposed
--    along dimensions 2 and 3 and afterwards along dimensions 1 and 2.
-- 3. @decoderAttentionMask@ is unsqueezed at dimension index 1 and added to
--    the transposed relative-position tensor from step 2, which yields the
--    attention bias handed to the stack.
-- 4. @crossAttentionMask@ is unsqueezed at dimension index 1.
-- 5. 'tStack' receives the results of steps 1, 3 and 4 together with the
--    untouched @encoderOutput@.
-- 6. The stack output is passed through 'tFinalLayerNorm' and then
--    'tFinalDropout', producing @output@ and the final generator.
--
-- All shape-level operations ('transpose', 'unsqueeze', 'sUnsqueeze', 'add')
-- run in the underlying 'MonadThrow' monad and are lifted with 'ilift'; the
-- corresponding @Catch@ constraints in the context make shape errors surface
-- at the type level where the shapes are statically known.
instance
  ( HasForward
      initialLayerNorm
      (Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      initialDropout
      tensor0
      generatorDevice0
      tensor1
      generatorDevice1,
    HasForward
      relPosEnc
      (Tensor decoderRelPosGradient decoderRelPosLayout decoderRelPosDevice decoderRelPosDataType decoderRelPosShape)
      generatorDevice1
      (Tensor decoderRelPosEncGradient decoderRelPosEncLayout decoderRelPosEncDevice decoderRelPosEncDataType decoderRelPosEncShape)
      generatorDevice2,
    HasForward
      stack
      ( tensor1,
        Tensor encoderOutputGradient encoderOutputLayout encoderOutputDevice encoderOutputDataType encoderOutputShape,
        Tensor
          (decoderRelPosEncGradient <|> decoderAttentionMaskGradient)
          (decoderRelPosEncLayout <+> decoderAttentionMaskLayout)
          (decoderRelPosEncDevice <+> decoderAttentionMaskDevice)
          (decoderRelPosEncDataType <+> decoderAttentionMaskDataType)
          (BroadcastShapesF doubleTransposedDecoderRelPosEncShape unsqueezedDecoderAttentionMaskShape),
        Tensor
          crossAttentionMaskGradient
          crossAttentionMaskLayout
          crossAttentionMaskDevice
          crossAttentionMaskDataType
          unsqueezedCrossAttentionMaskShape
      )
      generatorDevice2
      tensor3
      generatorDevice3,
    transposedDecoderRelPosEncShape ~ TransposeF ('SelectDim ('ByIndex 2)) ('SelectDim ('ByIndex 3)) decoderRelPosEncShape,
    Catch transposedDecoderRelPosEncShape,
    doubleTransposedDecoderRelPosEncShape ~ TransposeF ('SelectDim ('ByIndex 1)) ('SelectDim ('ByIndex 2)) transposedDecoderRelPosEncShape,
    Catch doubleTransposedDecoderRelPosEncShape,
    unsqueezedDecoderAttentionMaskShape ~ UnsqueezeF ('SelectDim ('ByIndex 1)) decoderAttentionMaskShape,
    Catch unsqueezedDecoderAttentionMaskShape,
    unsqueezedCrossAttentionMaskShape ~ UnsqueezeF ('SelectDim ('ByIndex 1)) crossAttentionMaskShape,
    Catch unsqueezedCrossAttentionMaskShape,
    Catch (BroadcastShapesF doubleTransposedDecoderRelPosEncShape unsqueezedDecoderAttentionMaskShape),
    HasForward
      finalLayerNorm
      tensor3
      generatorDevice3
      tensor4
      generatorDevice4,
    HasForward
      finalDropout
      tensor4
      generatorDevice4
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GTransformer () relPosEnc initialLayerNorm initialDropout stack finalLayerNorm finalDropout)
    ( Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape,
      Tensor encoderOutputGradient encoderOutputLayout encoderOutputDevice encoderOutputDataType encoderOutputShape,
      Tensor decoderRelPosGradient decoderRelPosLayout decoderRelPosDevice decoderRelPosDataType decoderRelPosShape,
      Tensor decoderAttentionMaskGradient decoderAttentionMaskLayout decoderAttentionMaskDevice decoderAttentionMaskDataType decoderAttentionMaskShape,
      Tensor crossAttentionMaskGradient crossAttentionMaskLayout crossAttentionMaskDevice crossAttentionMaskDataType crossAttentionMaskShape
    )
    generatorDevice
    output
    generatorOutputDevice
  where
  -- The record wildcard brings the field selectors' values ('tRelPosEnc',
  -- 'tInitialLayerNorm', 'tInitialDropout', 'tStack', 'tFinalLayerNorm',
  -- 'tFinalDropout') into scope; 'tPosEnc' is @()@ here and is not used.
  forward GTransformer {..} (decoderInput, encoderOutput, decoderRelPos, decoderAttentionMask, crossAttentionMask) =
    let -- Relative-position encoding, rearranged by two successive
        -- transpositions. Runs 'tRelPosEnc', so it moves the generator from
        -- @generatorDevice1@ to @generatorDevice2@.
        decoderRelPosBias =
          ireturn decoderRelPos
            >>>= IxStateT . forward tRelPosEnc
            >>>= ilift . transpose @('SelectDim ('ByIndex 2)) @('SelectDim ('ByIndex 3))
            >>>= ilift . transpose @('SelectDim ('ByIndex 1)) @('SelectDim ('ByIndex 2))
        -- Attention bias for the stack: the rearranged relative-position
        -- tensor plus the attention mask unsqueezed at index 1. The section
        -- @(mask >>=) . add@ first unsqueezes the mask and then evaluates
        -- @add relPosBias unsqueezedMask@ (relative-position tensor is the
        -- left operand, matching the constraint on the stack input).
        decoderAttentionBias =
          decoderRelPosBias
            >>>= ilift . (sUnsqueeze (SSelectDim $ SByIndex @1) decoderAttentionMask >>=) . add
        -- Cross-attention mask with an extra dimension at index 1. Pure
        -- tensor operation, so the generator index is left unchanged.
        crossAttentionBias =
          ilift $ unsqueeze @('SelectDim ('ByIndex 1)) crossAttentionMask
     in runIxStateT $
          ireturn decoderInput
            >>>= IxStateT . forward tInitialLayerNorm
            >>>= IxStateT . forward tInitialDropout
            >>>= ( \decoderInput' ->
                     -- The biases are sequenced here (after the initial
                     -- dropout) so that the generator is threaded in the
                     -- order demanded by the instance context.
                     decoderAttentionBias
                       >>>= ( \decoderAttentionBias' ->
                                crossAttentionBias
                                  >>>= ( \crossAttentionBias' ->
                                           IxStateT $
                                             forward
                                               tStack
                                               ( decoderInput',
                                                 encoderOutput,
                                                 decoderAttentionBias',
                                                 crossAttentionBias'
                                               )
                                       )
                            )
                 )
            >>>= IxStateT . forward tFinalLayerNorm
            >>>= IxStateT . forward tFinalDropout