{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
module Torch.GraduallyTyped.NN.Transformer.BART.Large where
import Data.Singletons (SingI (..))
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import GHC.TypeLits (Nat)
import Torch.GraduallyTyped.Device (Device, DeviceType, SDevice)
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.BART.Common (BARTModelF, bartModelSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead, TransformerHead)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Size (..))
-- | Number of layers in the BART-Large configuration, as a type-level
-- natural (12).
--
-- NOTE(review): whether this counts encoder layers, decoder layers, or
-- each of them is decided by 'BARTModelF' / 'bartModelSpec'; confirm there.
type BARTLargeNumLayers = 12

-- | Value-level singleton for 'BARTLargeNumLayers'.
--
-- Obtained from the generic 'sing'; the type signature alone selects the
-- 'SingI' instance, so no type application is needed.
bartLargeNumLayers :: SNat BARTLargeNumLayers
bartLargeNumLayers = sing
-- | Head dimension of BART-Large: a dimension named @\"*\"@ of size 16.
--
-- NOTE(review): 16 * 64 ('BARTLargeHeadEmbedDim') = 1024
-- ('BARTLargeEmbedDim'), which suggests this is the number of attention
-- heads rather than a per-head size; confirm against 'bartModelSpec'.
type BARTLargeHeadDim = 'Dim ('Name "*") ('Size 16)

-- | Value-level singleton for 'BARTLargeHeadDim'.
--
-- Obtained from the generic 'sing'; the type signature alone selects the
-- 'SingI' instance.
bartLargeHeadDim :: SDim BARTLargeHeadDim
bartLargeHeadDim = sing
-- | Head embedding dimension of BART-Large: a dimension named @\"*\"@ of
-- size 64.
--
-- NOTE(review): presumably the per-head embedding size (64 * 16 heads =
-- 1024); confirm against 'bartModelSpec'.
type BARTLargeHeadEmbedDim = 'Dim ('Name "*") ('Size 64)

-- | Value-level singleton for 'BARTLargeHeadEmbedDim'.
--
-- Obtained from the generic 'sing'; the type signature alone selects the
-- 'SingI' instance.
bartLargeHeadEmbedDim :: SDim BARTLargeHeadEmbedDim
bartLargeHeadEmbedDim = sing
-- | Embedding dimension of BART-Large: a dimension named @\"*\"@ of size
-- 1024.
--
-- It has the same size as 'BARTLargeInputEmbedDim' but is a separate
-- synonym, so the two stay distinguishable at the type level.
type BARTLargeEmbedDim = 'Dim ('Name "*") ('Size 1024)

-- | Value-level singleton for 'BARTLargeEmbedDim'.
--
-- Obtained from the generic 'sing'; the type signature alone selects the
-- 'SingI' instance.
bartLargeEmbedDim :: SDim BARTLargeEmbedDim
bartLargeEmbedDim = sing
-- | Input embedding dimension of BART-Large: a dimension named @\"*\"@ of
-- size 1024.
type BARTLargeInputEmbedDim = 'Dim ('Name "*") ('Size 1024)

-- | Value-level singleton for 'BARTLargeInputEmbedDim'.
--
-- Obtained from the generic 'sing'; the type signature alone selects the
-- 'SingI' instance.
bartLargeInputEmbedDim :: SDim BARTLargeInputEmbedDim
bartLargeInputEmbedDim = sing
-- | Feed-forward network dimension of BART-Large: a dimension named
-- @\"*\"@ of size 4096.
type BARTLargeFFNDim = 'Dim ('Name "*") ('Size 4096)

-- | Value-level singleton for 'BARTLargeFFNDim'.
--
-- Obtained from the generic 'sing'; the type signature alone selects the
-- 'SingI' instance.
bartLargeFFNDim :: SDim BARTLargeFFNDim
bartLargeFFNDim = sing
-- | Vocabulary dimension of BART-Large: a dimension named @\"*\"@ of size
-- 50265.
type BARTLargeVocabDim = 'Dim ('Name "*") ('Size 50265)

-- | Value-level singleton for 'BARTLargeVocabDim'.
--
-- Obtained from the generic 'sing'; the type signature alone selects the
-- 'SingI' instance.
bartLargeVocabDim :: SDim BARTLargeVocabDim
bartLargeVocabDim = sing
-- | Model type of BART-Large: 'BARTModelF' instantiated with the
-- BART-Large layer count and dimensions declared in this module.
--
-- Only the transformer head, gradient setting, device, and dropout flag
-- remain as parameters.
type BARTLarge
  (transformerHead :: TransformerHead)
  (gradient :: Gradient RequiresGradient)
  (device :: Device (DeviceType Nat))
  (hasDropout :: HasDropout) =
  BARTModelF
    transformerHead
    BARTLargeNumLayers
    gradient
    device
    BARTLargeHeadDim
    BARTLargeHeadEmbedDim
    BARTLargeEmbedDim
    BARTLargeInputEmbedDim
    BARTLargeFFNDim
    BARTLargeVocabDim
    hasDropout
-- | Model specification for BART-Large.
--
-- Applies 'bartModelSpec' to the given transformer head and the BART-Large
-- layer-count singleton 'bartLargeNumLayers'. The gradient, device, and
-- dropout singletons are the remaining arguments of that partial
-- application.
--
-- The dimension types are not passed as values. They are fixed by the
-- 'BARTLarge' result type, presumably through type inference on
-- 'ModelSpec' / 'BARTModelF' — NOTE(review): confirm if 'BARTLarge'
-- changes.
bartLargeSpec ::
  STransformerHead transformerHead ->
  SGradient gradient ->
  SDevice device ->
  SHasDropout hasDropout ->
  ModelSpec (BARTLarge transformerHead gradient device hasDropout)
bartLargeSpec transformerHead = bartModelSpec transformerHead bartLargeNumLayers