{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
module Torch.GraduallyTyped.NN.Transformer.BART.Base where
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat)
import Torch.GraduallyTyped.Device (Device, DeviceType, SDevice)
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.BART.Common (BARTModelF, bartModelSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead, TransformerHead)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Size (..))
-- | Number of layers of BART-Base, passed as the @numLayers@ argument of
-- 'BARTModelF'.
type BARTBaseNumLayers = 6

-- | Term-level singleton for 'BARTBaseNumLayers', produced by 'sing'.
bartBaseNumLayers :: SNat BARTBaseNumLayers
bartBaseNumLayers = sing
-- | Head dimension of BART-Base (size 12), passed as the @headDim@ argument
-- of 'BARTModelF'. NOTE(review): presumably the number of attention heads —
-- confirm against 'BARTModelF'.
type BARTBaseHeadDim = 'Dim ('Name "*") ('Size 12)

-- | Term-level singleton for 'BARTBaseHeadDim', produced by 'sing'.
bartBaseHeadDim :: SDim BARTBaseHeadDim
bartBaseHeadDim = sing
-- | Per-head embedding dimension of BART-Base (size 64), passed as the
-- @headEmbedDim@ argument of 'BARTModelF'.
-- Note that 12 (head dim) × 64 = 768, the value of 'BARTBaseEmbedDim'.
type BARTBaseHeadEmbedDim = 'Dim ('Name "*") ('Size 64)

-- | Term-level singleton for 'BARTBaseHeadEmbedDim', produced by 'sing'.
bartBaseHeadEmbedDim :: SDim BARTBaseHeadEmbedDim
bartBaseHeadEmbedDim = sing
-- | Embedding dimension of BART-Base (size 768), passed as the @embedDim@
-- argument of 'BARTModelF'.
type BARTBaseEmbedDim = 'Dim ('Name "*") ('Size 768)

-- | Term-level singleton for 'BARTBaseEmbedDim', produced by 'sing'.
-- The signature names 'BARTBaseEmbedDim' itself, not 'BARTBaseInputEmbedDim':
-- the two currently expand to the same type, but they are independent knobs.
bartBaseEmbedDim :: SDim BARTBaseEmbedDim
bartBaseEmbedDim = sing
-- | Input embedding dimension of BART-Base (size 768), passed as the
-- @inputEmbedDim@ argument of 'BARTModelF'.
type BARTBaseInputEmbedDim = 'Dim ('Name "*") ('Size 768)

-- | Term-level singleton for 'BARTBaseInputEmbedDim', produced by 'sing'.
bartBaseInputEmbedDim :: SDim BARTBaseInputEmbedDim
bartBaseInputEmbedDim = sing
-- | Feed-forward (FFN) dimension of BART-Base (size 3072), passed as the
-- @ffnDim@ argument of 'BARTModelF'.
type BARTBaseFFNDim = 'Dim ('Name "*") ('Size 3072)

-- | Term-level singleton for 'BARTBaseFFNDim', produced by 'sing'.
bartBaseFFNDim :: SDim BARTBaseFFNDim
bartBaseFFNDim = sing
-- | Vocabulary dimension of BART-Base (size 50265), passed as the
-- @vocabDim@ argument of 'BARTModelF'.
type BARTBaseVocabDim = 'Dim ('Name "*") ('Size 50265)

-- | Term-level singleton for 'BARTBaseVocabDim', produced by 'sing'.
bartBaseVocabDim :: SDim BARTBaseVocabDim
bartBaseVocabDim = sing
-- | Type of a BART-Base model.
--
-- This is 'BARTModelF' with every BART-Base hyper-parameter from this module
-- filled in. The caller chooses the transformer head, the gradient setting,
-- the device, and whether dropout is enabled.
type BARTBase
  (transformerHead :: TransformerHead)
  (gradient :: Gradient RequiresGradient)
  (device :: Device (DeviceType Nat))
  (hasDropout :: HasDropout) =
  BARTModelF
    transformerHead
    BARTBaseNumLayers
    gradient
    device
    BARTBaseHeadDim
    BARTBaseHeadEmbedDim
    BARTBaseEmbedDim
    BARTBaseInputEmbedDim
    BARTBaseFFNDim
    BARTBaseVocabDim
    hasDropout
-- | Model specification for BART-Base.
--
-- This delegates to 'bartModelSpec' and supplies 'bartBaseNumLayers' as the
-- layer count. The remaining BART-Base dimensions are fixed at the type level
-- by 'BARTBase'. They are concrete, so 'bartModelSpec''s @SingI@ constraints
-- on them are discharged here.
bartBaseSpec ::
  -- | transformer head to attach
  STransformerHead transformerHead ->
  -- | whether parameters require gradients
  SGradient gradient ->
  -- | device on which the model lives
  SDevice device ->
  -- | whether dropout is enabled
  SHasDropout hasDropout ->
  ModelSpec (BARTBase transformerHead gradient device hasDropout)
bartBaseSpec transformerHead = bartModelSpec transformerHead bartBaseNumLayers