{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeApplications #-}
module Torch.GraduallyTyped.NN.Transformer.Pegasus.XSum where
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat)
import Torch.GraduallyTyped.Device (Device, DeviceType, SDevice)
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.Pegasus.Common (PegasusModelF, pegasusModelSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead, TransformerHead)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Size (..))
-- | Number of layers of the Pegasus-XSum model, as a type-level natural.
-- It is passed as the @numLayers@ argument of 'PegasusModelF'.
type PegasusXSumNumLayers = 16

-- | Term-level singleton for 'PegasusXSumNumLayers'.
pegasusXSumNumLayers :: SNat PegasusXSumNumLayers
pegasusXSumNumLayers = sing @PegasusXSumNumLayers
-- | Pegasus-XSum head dimension (size 16), passed as the @headDim@
-- argument of 'PegasusModelF'.
-- NOTE(review): presumably the number of attention heads — confirm
-- against 'PegasusModelF'.
type PegasusXSumHeadDim = 'Dim ('Name "*") ('Size 16)

-- | Pegasus-XSum head embedding dimension (size 64), passed as the
-- @headEmbedDim@ argument of 'PegasusModelF'.
type PegasusXSumHeadEmbedDim = 'Dim ('Name "*") ('Size 64)

-- | Pegasus-XSum embedding dimension (size 1024), passed as the
-- @embedDim@ argument of 'PegasusModelF'.
type PegasusXSumEmbedDim = 'Dim ('Name "*") ('Size 1024)

-- | Pegasus-XSum input embedding dimension (size 1024), passed as the
-- @inputEmbedDim@ argument of 'PegasusModelF'.
type PegasusXSumInputEmbedDim = 'Dim ('Name "*") ('Size 1024)

-- | Term-level singleton for 'PegasusXSumInputEmbedDim'.
--
-- The explicit type application pins the kind of the type argument.
-- It matches how 'pegasusXSumNumLayers' is written.
pegasusXSumInputEmbedDim :: SDim PegasusXSumInputEmbedDim
pegasusXSumInputEmbedDim = sing @PegasusXSumInputEmbedDim
-- | Pegasus-XSum feed-forward network dimension (size 4096), passed as
-- the @ffnDim@ argument of 'PegasusModelF'.
type PegasusXSumFFNDim = 'Dim ('Name "*") ('Size 4096)

-- | Pegasus-XSum vocabulary dimension (size 96103), passed as the
-- @vocabDim@ argument of 'PegasusModelF'.
type PegasusXSumVocabDim = 'Dim ('Name "*") ('Size 96103)

-- | Term-level singleton for 'PegasusXSumVocabDim'.
--
-- The explicit type application pins the kind of the type argument.
-- It matches how 'pegasusXSumNumLayers' is written.
pegasusXSumVocabDim :: SDim PegasusXSumVocabDim
pegasusXSumVocabDim = sing @PegasusXSumVocabDim
-- | The Pegasus-XSum model type.
--
-- This is 'PegasusModelF' with the Pegasus-XSum layer count and
-- dimensions filled in. The transformer head, gradient setting, device
-- and dropout choice are left to the caller.
type PegasusXSum
  (transformerHead :: TransformerHead)
  (gradient :: Gradient RequiresGradient)
  (device :: Device (DeviceType Nat))
  (hasDropout :: HasDropout) =
  PegasusModelF
    transformerHead
    PegasusXSumNumLayers
    gradient
    device
    PegasusXSumHeadDim
    PegasusXSumHeadEmbedDim
    PegasusXSumEmbedDim
    PegasusXSumInputEmbedDim
    PegasusXSumFFNDim
    PegasusXSumVocabDim
    hasDropout
-- | Model specification for Pegasus-XSum.
--
-- This delegates to 'pegasusModelSpec' and supplies the layer count via
-- 'pegasusXSumNumLayers'. The remaining dimensions are fixed by the
-- 'PegasusXSum' result type. They are resolved through the 'SingI'
-- constraints of 'pegasusModelSpec'.
--
-- The arguments, in order: transformer head, gradient setting, device,
-- and dropout choice.
pegasusXSumSpec ::
  STransformerHead transformerHead ->
  SGradient gradient ->
  SDevice device ->
  SHasDropout hasDropout ->
  ModelSpec (PegasusXSum transformerHead gradient device hasDropout)
pegasusXSumSpec transformerHead =
  pegasusModelSpec transformerHead pegasusXSumNumLayers