{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
module Torch.GraduallyTyped.NN.Transformer.T5.Common where
import Control.Monad.Catch (MonadThrow)
import Data.Kind (Type)
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (GEncoderDecoderTransformerF, GSimplifiedEncoderDecoderTransformer (..), encoderDecoderTransformerSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (MkRelPos (..), MkTransformerAttentionMask (..), MkTransformerCrossAttentionMask (..), MkTransformerDecoderAttentionMask (..), MkTransformerPaddingMask (..), STransformerHead (), STransformerStyle (SByT5, ST5), ShiftRight (..), TransformerHead (..), TransformerStyle (ByT5, T5), mkTransformerInput)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (SGetDim, Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>))
-- | Element type used for all T5 model parameters.
type T5DType = 'Float

-- | Value-level singleton for 'T5DType', obtained from its 'SingI' instance.
t5DType :: SDType T5DType
t5DType = sing @T5DType
-- | Data type of T5 model parameters, wrapping 'T5DType'.
type T5DataType = 'DataType T5DType

-- | Value-level singleton for 'T5DataType', obtained from its 'SingI' instance.
t5DataType :: SDataType T5DataType
t5DataType = sing @T5DataType
-- | Dropout rate handed to 'encoderDecoderTransformerSpec' when building T5
-- models. It is presumably only applied when dropout is enabled via
-- 'SHasDropout' — confirm in the transformer layers.
t5DropoutP :: Double
t5DropoutP = 0.1
-- | Dimension of the relative-position buckets used by T5 attention:
-- an unnamed (@"*"@) dimension of size 32.
type T5RelPosEncBucketDim = 'Dim ('Name "*") ('Size 32)

-- | Value-level singleton for 'T5RelPosEncBucketDim'.
t5RelPosEncBucketDim :: SDim T5RelPosEncBucketDim
t5RelPosEncBucketDim = sing @T5RelPosEncBucketDim
-- | Epsilon handed to 'encoderDecoderTransformerSpec' for T5 models.
-- Presumably this is the layer-norm epsilon — confirm in the spec builder.
t5Eps :: Double
t5Eps = 1e-6
-- | Maximum distance handed to the T5 relative-position constructors
-- ('MkRelPos' / 'MkDecoderRelPos'). Presumably distances beyond this share
-- the last bucket — confirm in the relative-position implementation.
t5MaxDistance :: Int
t5MaxDistance = 128
-- | Token id T5 uses for padding. It is used to build padding masks and is
-- passed to 'mkTransformerInput' by 'mkT5Input'.
t5PadTokenId :: Int
t5PadTokenId = 0
-- | Beginning-of-sequence token id for T5. It is deliberately an alias of
-- 't5PadTokenId', so the padding id also serves as the right-shift fill value
-- for decoder inputs.
t5BOSTokenId :: Int
t5BOSTokenId = t5PadTokenId
-- | End-of-sequence token id for T5.
t5EOSTokenId :: Int
t5EOSTokenId = 1
-- | Bias handed to the T5 attention-mask constructors (self, cross and
-- decoder). Presumably it is added to the logits of masked-out positions —
-- confirm in the mask implementations.
t5AttentionMaskBias :: Double
t5AttentionMaskBias = -10000
-- | Type-level constructor for T5-family models.
--
-- Given a 'TransformerStyle' and the model hyperparameters, this family
-- computes the concrete 'GSimplifiedEncoderDecoderTransformer' type that
-- 't5ModelSpec' describes. The parameter data type is fixed to 'T5DataType'
-- and the relative-position bucket dimension to 'T5RelPosEncBucketDim'.
--
-- Only the @'T5@ and @'ByT5@ styles have equations, and they are identical
-- apart from the style index. For any other style the family does not reduce.
type T5ModelF ::
  TransformerStyle ->
  TransformerHead ->
  Nat ->
  Nat ->
  Gradient RequiresGradient ->
  Device (DeviceType Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  Dim (Name Symbol) (Size Nat) ->
  HasDropout ->
  Type
type family
  T5ModelF style transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout
  where
    T5ModelF 'T5 transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout =
      GSimplifiedEncoderDecoderTransformer
        ( GEncoderDecoderTransformerF
            'T5
            transformerHead
            numEncoderLayers
            numDecoderLayers
            gradient
            device
            T5DataType
            headDim
            headEmbedDim
            embedDim
            inputEmbedDim
            ffnDim
            T5RelPosEncBucketDim
            vocabDim
            hasDropout
        )
        -- encoder positions
        (MkRelPos T5RelPosEncBucketDim)
        -- decoder positions
        (MkRelPos T5RelPosEncBucketDim)
        MkTransformerPaddingMask
        (MkTransformerAttentionMask T5DataType)
        (MkTransformerCrossAttentionMask T5DataType)
        (MkTransformerDecoderAttentionMask T5DataType)
    T5ModelF 'ByT5 transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout =
      GSimplifiedEncoderDecoderTransformer
        ( GEncoderDecoderTransformerF
            'ByT5
            transformerHead
            numEncoderLayers
            numDecoderLayers
            gradient
            device
            T5DataType
            headDim
            headEmbedDim
            embedDim
            inputEmbedDim
            ffnDim
            T5RelPosEncBucketDim
            vocabDim
            hasDropout
        )
        -- encoder positions
        (MkRelPos T5RelPosEncBucketDim)
        -- decoder positions
        (MkRelPos T5RelPosEncBucketDim)
        MkTransformerPaddingMask
        (MkTransformerAttentionMask T5DataType)
        (MkTransformerCrossAttentionMask T5DataType)
        (MkTransformerDecoderAttentionMask T5DataType)
-- | Build the model specification for a T5-family encoder-decoder transformer.
--
-- Only the 'ST5' and 'SByT5' styles are supported. Both are configured
-- identically from the module-level T5 constants:
--
-- * the underlying encoder-decoder spec comes from
--   'encoderDecoderTransformerSpec', using 't5DataType',
--   't5RelPosEncBucketDim', 't5DropoutP' and 't5Eps';
-- * the two 'ShiftRight' fields use fill values 't5BOSTokenId' and @0@,
--   respectively (presumably the decoder-input shift and the padding-mask
--   shift — confirm against 'GSimplifiedEncoderDecoderTransformer');
-- * encoder and decoder positions use 'MkRelPos' / 'MkDecoderRelPos' with
--   't5RelPosEncBucketDim' and 't5MaxDistance';
-- * the padding mask uses 't5PadTokenId';
-- * the self, cross and decoder attention masks use 't5DataType' and
--   't5AttentionMaskBias'.
--
-- The dimension type parameters (@headDim@, @headEmbedDim@, @embedDim@,
-- @inputEmbedDim@, @ffnDim@, @vocabDim@) are reflected to values through their
-- 'SingI' instances, so they must be fixed by type application or inference.
--
-- Passing any other style calls 'error', because 'T5ModelF' has no equation
-- for it.
t5ModelSpec ::
  forall style transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout.
  ( SingI headDim,
    SingI headEmbedDim,
    SingI embedDim,
    SingI inputEmbedDim,
    SingI ffnDim,
    SingI vocabDim
  ) =>
  STransformerStyle style ->
  STransformerHead transformerHead ->
  SNat numEncoderLayers ->
  SNat numDecoderLayers ->
  SGradient gradient ->
  SDevice device ->
  SHasDropout hasDropout ->
  ModelSpec (T5ModelF style transformerHead numEncoderLayers numDecoderLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout)
t5ModelSpec style transformerHead numEncoderLayers numDecoderLayers gradient device hasDropout =
  case style of
    ST5 ->
      GSimplifiedEncoderDecoderTransformer
        (modelSpec' ST5)
        (ShiftRight t5BOSTokenId)
        (ShiftRight 0)
        (MkRelPos t5RelPosEncBucketDim t5MaxDistance)
        (MkDecoderRelPos t5RelPosEncBucketDim t5MaxDistance)
        (MkTransformerPaddingMask t5PadTokenId)
        (MkTransformerAttentionMask t5DataType t5AttentionMaskBias)
        (MkTransformerCrossAttentionMask t5DataType t5AttentionMaskBias)
        (MkTransformerDecoderAttentionMask t5DataType t5AttentionMaskBias)
    SByT5 ->
      GSimplifiedEncoderDecoderTransformer
        (modelSpec' SByT5)
        (ShiftRight t5BOSTokenId)
        (ShiftRight 0)
        (MkRelPos t5RelPosEncBucketDim t5MaxDistance)
        (MkDecoderRelPos t5RelPosEncBucketDim t5MaxDistance)
        (MkTransformerPaddingMask t5PadTokenId)
        (MkTransformerAttentionMask t5DataType t5AttentionMaskBias)
        (MkTransformerCrossAttentionMask t5DataType t5AttentionMaskBias)
        (MkTransformerDecoderAttentionMask t5DataType t5AttentionMaskBias)
    -- 'T5ModelF' only reduces for 'T5 and 'ByT5, so no spec can be built for
    -- other styles; fail loudly with a message instead of a bare 'undefined'.
    _ -> error "t5ModelSpec: unsupported transformer style (only T5 and ByT5 are supported)"
  where
    -- Encoder-decoder spec shared by both styles. The partial type signature
    -- leaves the (very large) result type to inference; the corresponding
    -- warning is disabled at the top of the module.
    modelSpec' :: _
    modelSpec' style' =
      encoderDecoderTransformerSpec
        style'
        transformerHead
        numEncoderLayers
        numDecoderLayers
        gradient
        device
        t5DataType
        (sing @headDim)
        (sing @headEmbedDim)
        (sing @embedDim)
        (sing @inputEmbedDim)
        (sing @ffnDim)
        t5RelPosEncBucketDim
        (sing @vocabDim)
        hasDropout
        t5DropoutP
        t5Eps
-- | Build a T5 input tensor from a batch of token-id sequences.
--
-- This is 'mkTransformerInput' with the padding token fixed to
-- 't5PadTokenId'. The result is a tensor without gradient, with dense layout,
-- 'Int64' elements and shape @[batchDim, seqDim]@ on the given @device@.
-- Presumably, sequences shorter than @seqDim@ are padded with 't5PadTokenId'
-- and failures (e.g. shape mismatches) are raised through 'MonadThrow' —
-- confirm in 'mkTransformerInput'.
mkT5Input ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  SDim batchDim ->
  SDim seqDim ->
  SDevice device ->
  [[Int]] ->
  m output
mkT5Input = mkTransformerInput t5PadTokenId