{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
module Torch.GraduallyTyped.NN.Transformer.T5.Large where
import Data.Singletons (SingI (sing))
import GHC.TypeLits (Nat)
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice)
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.T5.Common (T5ModelF, t5ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead, STransformerStyle (ST5), TransformerHead, TransformerStyle (T5))
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), Size (..))
-- | Number of transformer layers in T5-Large.
--
-- 'T5Large' uses this same value for both the encoder and the decoder
-- layer count.
type T5LargeNumLayers = 24

-- | Singleton for 'T5LargeNumLayers'.
--
-- The type signature pins the type-level natural, so a plain 'sing'
-- suffices (no @TypeApplications@ needed).
t5LargeNumLayers :: SNat T5LargeNumLayers
t5LargeNumLayers = sing
-- | Shorthand for a dimension named @"*"@ with the given size; every
-- T5-Large dimension synonym below is built from it.
type T5LargeDim (size :: Nat) = 'Dim ('Name "*") ('Size size)

-- | Head dimension (16). Presumably the number of attention heads; note
-- that 16 * 64 = 1024, the embedding size.
type T5LargeHeadDim = T5LargeDim 16

-- | Per-head embedding dimension (64).
type T5LargeHeadEmbedDim = T5LargeDim 64

-- | Model embedding dimension (1024).
type T5LargeEmbedDim = T5LargeDim 1024

-- | Input embedding dimension (1024, same as the model embedding size).
type T5LargeInputEmbedDim = T5LargeDim 1024

-- | Feed-forward inner dimension (4096).
type T5LargeFFNDim = T5LargeDim 4096

-- | Vocabulary dimension (32128).
type T5LargeVocabDim = T5LargeDim 32128
-- | The T5-Large model type.
--
-- This is 'T5ModelF' specialised to the T5 style and the T5-Large
-- hyper-parameter synonyms of this module. The transformer head,
-- gradient setting, device and dropout setting are left to the caller.
type T5Large
  (transformerHead :: TransformerHead)
  (gradient :: Gradient RequiresGradient)
  (device :: Device (DeviceType Nat))
  (hasDropout :: HasDropout) =
  T5ModelF
    'T5
    transformerHead
    T5LargeNumLayers -- number of encoder layers
    T5LargeNumLayers -- number of decoder layers
    gradient
    device
    T5LargeHeadDim
    T5LargeHeadEmbedDim
    T5LargeEmbedDim
    T5LargeInputEmbedDim
    T5LargeFFNDim
    T5LargeVocabDim
    hasDropout
-- | Model specification for T5-Large.
--
-- Delegates to 't5ModelSpec' with the T5 style ('ST5') and
-- 't5LargeNumLayers' as both the encoder and the decoder layer count.
-- The head, gradient, device and dropout singletons are passed through
-- unchanged. The fixed T5-Large dimensions come from the 'T5Large' type.
t5LargeSpec ::
  -- | transformer head singleton
  STransformerHead transformerHead ->
  -- | gradient-requirement singleton
  SGradient gradient ->
  -- | device singleton
  SDevice device ->
  -- | dropout singleton
  SHasDropout hasDropout ->
  ModelSpec (T5Large transformerHead gradient device hasDropout)
t5LargeSpec transformerHead gradient device hasDropout =
  t5ModelSpec
    ST5
    transformerHead
    t5LargeNumLayers -- encoder layers
    t5LargeNumLayers -- decoder layers
    gradient
    device
    hasDropout