{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Transformer.Pegasus.Common where
import Control.Monad.Catch (MonadThrow)
import Data.Kind (Type)
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (GEncoderDecoderTransformerF, GSimplifiedEncoderDecoderTransformer (..), encoderDecoderTransformerSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (MkAbsPos (..), MkTransformerAttentionMask (..), MkTransformerCrossAttentionMask (..), MkTransformerDecoderAttentionMask (..), MkTransformerPaddingMask (..), STransformerHead, STransformerStyle (SPegasus), ShiftRight (..), TransformerHead (..), TransformerStyle (Pegasus), mkTransformerInput)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (SGetDim, Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))
-- | Element dtype used by the Pegasus models defined in this module.
type PegasusDType = 'Float

-- | Value-level singleton for 'PegasusDType', obtained via 'sing'.
pegasusDType :: SDType PegasusDType
pegasusDType = sing @PegasusDType
-- | Data type ('DataType' wrapping 'PegasusDType') used for the Pegasus
-- transformer and its attention-mask builders.
type PegasusDataType = 'DataType PegasusDType

-- | Value-level singleton for 'PegasusDataType', obtained via 'sing'.
pegasusDataType :: SDataType PegasusDataType
pegasusDataType = sing @PegasusDataType
-- | Dropout probability that 'pegasusModelSpec' hands to
-- 'encoderDecoderTransformerSpec'.
pegasusDropoutP :: Double
pegasusDropoutP = 0.1
-- | Positional-encoding dimension: a dimension named @"*"@ of size 512.
--
-- NOTE(review): the size is duplicated by hand in
-- 'pegasusMaxPositionEmbeddings'; keep the two in sync.
type PegasusPosEncDim = 'Dim ('Name "*") ('Size 512)

-- | Value-level singleton for 'PegasusPosEncDim', obtained via 'sing'.
pegasusPosEncDim :: SDim PegasusPosEncDim
pegasusPosEncDim = sing @PegasusPosEncDim
-- | Epsilon that 'pegasusModelSpec' hands to 'encoderDecoderTransformerSpec'
-- (presumably the layer-norm epsilon — TODO confirm).
pegasusEps :: Double
pegasusEps = 1e-5
-- | Maximum number of position embeddings.
--
-- NOTE(review): this must agree with the size of 'PegasusPosEncDim'
-- (also 512); the two values are currently kept in sync by hand.
pegasusMaxPositionEmbeddings :: Int
pegasusMaxPositionEmbeddings = 512
-- | Padding token id.
--
-- 'mkPegasusInput' and the padding-mask builder in 'pegasusModelSpec' are
-- parameterised by this value.
pegasusPadTokenId :: Int
pegasusPadTokenId = 0
-- | Begin-of-sequence token id.
--
-- Defined as 'pegasusPadTokenId', so the two ids always coincide.
-- 'pegasusModelSpec' uses it as the fill value of its first 'ShiftRight'.
pegasusBOSTokenId :: Int
pegasusBOSTokenId = pegasusPadTokenId
-- | End-of-sequence token id.
pegasusEOSTokenId :: Int
pegasusEOSTokenId = 1
-- | Bias that 'pegasusModelSpec' passes to the attention, cross-attention
-- and decoder-attention mask builders (presumably the additive bias at
-- masked-out positions — TODO confirm).
pegasusAttentionMaskBias :: Double
pegasusAttentionMaskBias = -10000
-- | Specifies the type of a Pegasus model.
--
-- The encoder and the decoder both use @numLayers@ layers. The data type is
-- fixed to 'PegasusDataType' and the positional-encoding dimension to
-- 'PegasusPosEncDim'. Encoder and decoder positions use 'MkAbsPos'.
type family
  PegasusModelF
    (transformerHead :: TransformerHead)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  PegasusModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout =
    GSimplifiedEncoderDecoderTransformer
      -- numLayers is used twice: once for the encoder, once for the decoder.
      (GEncoderDecoderTransformerF 'Pegasus transformerHead numLayers numLayers gradient device PegasusDataType headDim headEmbedDim embedDim inputEmbedDim ffnDim PegasusPosEncDim vocabDim hasDropout)
      MkAbsPos
      MkAbsPos
      MkTransformerPaddingMask
      (MkTransformerAttentionMask PegasusDataType)
      (MkTransformerCrossAttentionMask PegasusDataType)
      (MkTransformerDecoderAttentionMask PegasusDataType)
-- | Specifies the parameters of a Pegasus model.
--
-- The encoder and the decoder both get @numLayers@ layers. The following
-- are fixed to their Pegasus values:
--
--   * data type: 'pegasusDataType'
--   * positional-encoding dimension: 'pegasusPosEncDim'
--   * dropout probability: 'pegasusDropoutP'
--   * epsilon: 'pegasusEps'
--
-- The head, embedding, feed-forward and vocabulary dimensions come from the
-- 'SingI' instances of the corresponding type variables.
--
-- Arguments (all singletons):
--
--   * @transformerHead@: the transformer head.
--   * @numLayers@: layer count, shared by encoder and decoder.
--   * @gradient@: the gradient setting of the parameters.
--   * @device@: the target device.
--   * @hasDropout@: whether dropout is present.
pegasusModelSpec ::
  forall transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout.
  ( SingI headDim,
    SingI headEmbedDim,
    SingI embedDim,
    SingI inputEmbedDim,
    SingI ffnDim,
    SingI vocabDim
  ) =>
  STransformerHead transformerHead ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SHasDropout hasDropout ->
  ModelSpec (PegasusModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout)
pegasusModelSpec transformerHead numLayers gradient device hasDropout =
  GSimplifiedEncoderDecoderTransformer
    ( encoderDecoderTransformerSpec
        SPegasus
        transformerHead
        numLayers -- encoder layers
        numLayers -- decoder layers
        gradient
        device
        pegasusDataType
        (sing @headDim)
        (sing @headEmbedDim)
        (sing @embedDim)
        (sing @inputEmbedDim)
        (sing @ffnDim)
        pegasusPosEncDim
        (sing @vocabDim)
        hasDropout
        pegasusDropoutP
        pegasusEps
    )
    -- NOTE(review): these fill the constructor's two 'ShiftRight Int'
    -- fields, in order. They are presumably the decoder-input shift
    -- (filled with the BOS id) and the padding-mask shift (filled with 0).
    -- Confirm against GSimplifiedEncoderDecoderTransformer.
    (ShiftRight pegasusBOSTokenId)
    (ShiftRight 0)
    MkAbsPos
    MkAbsPos
    (MkTransformerPaddingMask pegasusPadTokenId)
    (MkTransformerAttentionMask pegasusDataType pegasusAttentionMaskBias)
    (MkTransformerCrossAttentionMask pegasusDataType pegasusAttentionMaskBias)
    (MkTransformerDecoderAttentionMask pegasusDataType pegasusAttentionMaskBias)
-- | Builds a Pegasus input tensor from a batch of token-id sequences.
--
-- This is a thin wrapper around 'mkTransformerInput'. It supplies
-- 'pegasusPadTokenId' as that function's leading 'Int' argument, which is
-- presumably the fill value for sequences shorter than @seqDim@ (confirm in
-- 'mkTransformerInput').
--
-- The result is a dense, gradient-free 'Int64' tensor of shape
-- @[batchDim, seqDim]@ on @device@. Failures are raised via 'MonadThrow';
-- presumably this happens when the ids do not fit the requested
-- dimensions (TODO confirm).
mkPegasusInput ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  SDim batchDim ->
  SDim seqDim ->
  SDevice device ->
  [[Int]] ->
  m output
mkPegasusInput = mkTransformerInput pegasusPadTokenId