{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Transformer.Pegasus.Common where

import Control.Monad.Catch (MonadThrow)
import Data.Kind (Type)
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (GEncoderDecoderTransformerF, GSimplifiedEncoderDecoderTransformer (..), encoderDecoderTransformerSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (MkAbsPos (..), MkTransformerAttentionMask (..), MkTransformerCrossAttentionMask (..), MkTransformerDecoderAttentionMask (..), MkTransformerPaddingMask (..), STransformerHead, STransformerStyle (SPegasus), ShiftRight (..), TransformerHead (..), TransformerStyle (Pegasus), mkTransformerInput)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (SGetDim, Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))

-- | Pegasus dType: the promoted @'Float@ constructor of 'DType'.
type PegasusDType = 'Float

-- | Pegasus dType singleton, i.e. the term-level witness of 'PegasusDType'.
pegasusDType :: SDType PegasusDType
pegasusDType = sing @PegasusDType

-- | Pegasus data type: 'PegasusDType' wrapped in the promoted @'DataType@
-- constructor.
type PegasusDataType = 'DataType PegasusDType

-- | Pegasus data type singleton, i.e. the term-level witness of
-- 'PegasusDataType'.
pegasusDataType :: SDataType PegasusDataType
pegasusDataType = sing @PegasusDataType

-- | Pegasus dropout rate.
-- 'dropout_rate = 0.1'
pegasusDropoutP :: Double
pegasusDropoutP = 0.1

-- | Pegasus positional encoding dimension: an unnamed (@"*"@) dimension of
-- size 512.
type PegasusPosEncDim = 'Dim ('Name "*") ('Size 512)

-- | Pegasus positional encoding dimension singleton, i.e. the term-level
-- witness of 'PegasusPosEncDim'.
pegasusPosEncDim :: SDim PegasusPosEncDim
pegasusPosEncDim = sing @PegasusPosEncDim

-- | Pegasus layer-norm epsilon.
pegasusEps :: Double
pegasusEps = 1e-5

-- | Pegasus maximum number of position embeddings.
-- 'max_position_embeddings = 512'
--
-- This is the term-level counterpart of the size fixed in
-- 'PegasusPosEncDim'; the two values are maintained separately.
pegasusMaxPositionEmbeddings :: Int
pegasusMaxPositionEmbeddings = 512

-- | Pegasus padding token id.
-- 'pad_token_id = 0'
pegasusPadTokenId :: Int
pegasusPadTokenId = 0

-- | Pegasus begin-of-sentence token id.
-- 'bos_token_id = 0'
--
-- Defined as 'pegasusPadTokenId', so the two ids always coincide.
pegasusBOSTokenId :: Int
pegasusBOSTokenId = pegasusPadTokenId

-- | Pegasus end-of-sentence token id.
-- 'eos_token_id = 1'
pegasusEOSTokenId :: Int
pegasusEOSTokenId = 1

-- | Pegasus attention mask bias.
--
-- The same value is handed to the attention, cross-attention, and decoder
-- attention mask constructors in 'pegasusModelSpec'.
pegasusAttentionMaskBias :: Double
pegasusAttentionMaskBias = -10000

-- | Specifies the Pegasus model.
--
-- Expands to a 'GSimplifiedEncoderDecoderTransformer' wrapping a
-- 'GEncoderDecoderTransformerF' in the @'Pegasus@ style. The single
-- @numLayers@ parameter is passed for both layer-count arguments of
-- 'GEncoderDecoderTransformerF', and the data type and positional encoding
-- dimension are fixed to 'PegasusDataType' and 'PegasusPosEncDim'.
type family
  PegasusModelF
    (transformerHead :: TransformerHead)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  PegasusModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout =
    GSimplifiedEncoderDecoderTransformer
      (GEncoderDecoderTransformerF 'Pegasus transformerHead numLayers numLayers gradient device PegasusDataType headDim headEmbedDim embedDim inputEmbedDim ffnDim PegasusPosEncDim vocabDim hasDropout)
      -- Position constructors for the two position slots.
      MkAbsPos
      MkAbsPos
      -- Mask constructors; the attention-style masks are all specialised to
      -- 'PegasusDataType'.
      MkTransformerPaddingMask
      (MkTransformerAttentionMask PegasusDataType)
      (MkTransformerCrossAttentionMask PegasusDataType)
      (MkTransformerDecoderAttentionMask PegasusDataType)

-- | Specifies the parameters of a Pegasus model.
--
-- - @transformerHead@: the head of the Pegasus model.
-- - @numLayers@: the number of layers in the Pegasus model; the same value is
--   passed for both layer-count arguments of 'encoderDecoderTransformerSpec'.
-- - @gradient@: whether to compute the gradient of the Pegasus model.
-- - @device@: the computational device on which the Pegasus model parameters are to be allocated.
-- - @hasDropout@: singleton selecting the @hasDropout@ type parameter of the model.
--
-- The dimensions @headDim@, @headEmbedDim@, @embedDim@, @inputEmbedDim@,
-- @ffnDim@, and @vocabDim@ have no term-level argument. They are chosen by
-- type application and turned into singletons through their 'SingI'
-- instances.
--
-- The data type, positional encoding dimension, dropout rate, layer-norm
-- epsilon, token ids, and attention mask bias are fixed to the @pegasus*@
-- constants of this module.
pegasusModelSpec ::
  forall transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout.
  ( SingI headDim,
    SingI headEmbedDim,
    SingI embedDim,
    SingI inputEmbedDim,
    SingI ffnDim,
    SingI vocabDim
  ) =>
  STransformerHead transformerHead ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SHasDropout hasDropout ->
  ModelSpec (PegasusModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout)
pegasusModelSpec transformerHead numLayers gradient device hasDropout =
  GSimplifiedEncoderDecoderTransformer
    ( encoderDecoderTransformerSpec
        SPegasus
        transformerHead
        -- The single layer count is used for both layer-count arguments.
        numLayers
        numLayers
        gradient
        device
        pegasusDataType
        (sing @headDim)
        (sing @headEmbedDim)
        (sing @embedDim)
        (sing @inputEmbedDim)
        (sing @ffnDim)
        pegasusPosEncDim
        (sing @vocabDim)
        hasDropout
        pegasusDropoutP
        pegasusEps
    )
    -- Two right-shift fill values: the first uses the BOS token id, the
    -- second a literal 0.
    (ShiftRight pegasusBOSTokenId)
    (ShiftRight 0)
    MkAbsPos
    MkAbsPos
    (MkTransformerPaddingMask pegasusPadTokenId)
    (MkTransformerAttentionMask pegasusDataType pegasusAttentionMaskBias)
    (MkTransformerCrossAttentionMask pegasusDataType pegasusAttentionMaskBias)
    (MkTransformerDecoderAttentionMask pegasusDataType pegasusAttentionMaskBias)

-- | Creates a Pegasus input tensor from a batch of token id lists.
--
-- This is 'mkTransformerInput' partially applied to 'pegasusPadTokenId'
-- (presumably the value used to pad shorter sequences; confirm against
-- 'mkTransformerInput').
--
-- - @batchDim@ and @seqDim@: singletons for the two dimensions of the result shape.
-- - @device@: the device of the resulting tensor.
-- - @[[Int]]@: the token ids, one inner list per batch entry.
--
-- The result is a dense 'Int64' tensor without gradient, of shape
-- @[batchDim, seqDim]@, produced in a 'MonadThrow' context.
mkPegasusInput ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  SDim batchDim ->
  SDim seqDim ->
  SDevice device ->
  [[Int]] ->
  m output
mkPegasusInput = mkTransformerInput pegasusPadTokenId