{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Transformer.BART.Common where

import Control.Monad.Catch (MonadThrow)
import Data.Kind (Type)
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (GEncoderDecoderTransformerF, GSimplifiedEncoderDecoderTransformer (..), encoderDecoderTransformerSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (MkAbsPos (..), MkTransformerAttentionMask (..), MkTransformerCrossAttentionMask (..), MkTransformerDecoderAttentionMask (..), MkTransformerPaddingMask (..), STransformerHead, STransformerStyle (SBART), ShiftRight (..), TransformerHead (..), TransformerStyle (BART), mkTransformerInput)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (SGetDim, Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))

-- | BART dType: the type-level element type used throughout this module,
-- fixed to the promoted constructor @'Float@.
type BARTDType = 'Float

-- | BART dType singleton.
--
-- Value-level witness of 'BARTDType', obtained from its 'SingI' instance.
bartDType :: SDType BARTDType
bartDType = sing @BARTDType

-- | BART data type: 'BARTDType' wrapped in the promoted @'DataType@ constructor.
type BARTDataType = 'DataType BARTDType

-- | BART data type singleton.
--
-- Value-level witness of 'BARTDataType', obtained from its 'SingI' instance.
bartDataType :: SDataType BARTDataType
bartDataType = sing @BARTDataType

-- | BART dropout rate.
-- @dropout_rate = 0.1@
bartDropoutP :: Double
bartDropoutP = 0.1

-- | BART positional encoding dimension: a dimension named @"*"@ of size 1026.
--
-- NOTE(review): 1026 equals 'bartMaxPositionEmbeddings' (1024) plus the
-- position offset of 2 passed to @MkAbsPosWithOffset@ in 'bartModelSpec';
-- presumably that is the reason for this size -- confirm.
type BARTPosEncDim = 'Dim ('Name "*") ('Size 1026)

-- | BART positional encoding dimension singleton.
--
-- Value-level witness of 'BARTPosEncDim', obtained from its 'SingI' instance.
bartPosEncDim :: SDim BARTPosEncDim
bartPosEncDim = sing @BARTPosEncDim

-- | BART layer-norm epsilon.
bartEps :: Double
bartEps = 1e-5

-- | BART maximum number of position embeddings.
-- @max_position_embeddings = 1024@
bartMaxPositionEmbeddings :: Int
bartMaxPositionEmbeddings = 1024

-- | BART padding token id.
-- @pad_token_id = 1@
bartPadTokenId :: Int
bartPadTokenId = 1

-- | BART begin-of-sentence token id.
-- @bos_token_id = 0@
bartBOSTokenId :: Int
bartBOSTokenId = 0

-- | BART end-of-sentence token id.
-- @eos_token_id = 2@
bartEOSTokenId :: Int
bartEOSTokenId = 2

-- | BART attention mask bias.
--
-- A large negative constant; 'bartModelSpec' passes it to the attention,
-- cross-attention, and decoder attention mask constructors.
bartAttentionMaskBias :: Double
bartAttentionMaskBias = -10000

-- | Specifies the BART model.
--
-- Expands to a 'GSimplifiedEncoderDecoderTransformer' whose inner model is
-- 'GEncoderDecoderTransformerF' instantiated with the @'BART@ style,
-- 'BARTDataType' and 'BARTPosEncDim'. The single @numLayers@ parameter is
-- supplied twice to 'GEncoderDecoderTransformerF', so the same layer count is
-- used in both of its layer-count positions. The remaining type arguments fix
-- the position and mask builder types stored alongside the model.
type family
  BARTModelF
    (transformerHead :: TransformerHead)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  BARTModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout =
    GSimplifiedEncoderDecoderTransformer
      (GEncoderDecoderTransformerF 'BART transformerHead numLayers numLayers gradient device BARTDataType headDim headEmbedDim embedDim inputEmbedDim ffnDim BARTPosEncDim vocabDim hasDropout)
      MkAbsPos
      MkAbsPos
      MkTransformerPaddingMask
      (MkTransformerAttentionMask BARTDataType)
      (MkTransformerCrossAttentionMask BARTDataType)
      (MkTransformerDecoderAttentionMask BARTDataType)

-- | Specifies the parameters of a BART model.
--
-- - @transformerHead@: the head of the BART model.
-- - @numLayers@: the number of layers in the BART model; the same value is
--   passed for both layer-count arguments of 'encoderDecoderTransformerSpec'.
-- - @gradient@: whether to compute the gradient of the BART model.
-- - @device@: the computational device on which the BART model parameters are to be allocated.
-- - @hasDropout@: singleton selecting the @hasDropout@ type parameter of the model.
--
-- The dimensions @headDim@, @headEmbedDim@, @embedDim@, @inputEmbedDim@,
-- @ffnDim@ and @vocabDim@ are not passed as values. Their singletons are
-- obtained from the 'SingI' constraints, so callers choose them through type
-- applications or a type annotation on the result.
bartModelSpec ::
  forall transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout.
  ( SingI headDim,
    SingI headEmbedDim,
    SingI embedDim,
    SingI inputEmbedDim,
    SingI ffnDim,
    SingI vocabDim
  ) =>
  STransformerHead transformerHead ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SHasDropout hasDropout ->
  ModelSpec (BARTModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout)
bartModelSpec transformerHead numLayers gradient device hasDropout =
  GSimplifiedEncoderDecoderTransformer
    -- Inner encoder-decoder model specification, fixed to the BART style,
    -- data type, positional encoding dimension, dropout rate and epsilon.
    ( encoderDecoderTransformerSpec
        SBART
        transformerHead
        numLayers
        numLayers
        gradient
        device
        bartDataType
        (sing @headDim)
        (sing @headEmbedDim)
        (sing @embedDim)
        (sing @inputEmbedDim)
        (sing @ffnDim)
        bartPosEncDim
        (sing @vocabDim)
        hasDropout
        bartDropoutP
        bartEps
    )
    -- Two right-shift fill values: the EOS token id and a literal 0.
    -- NOTE(review): which input each shift applies to is determined by the
    -- field order of 'GSimplifiedEncoderDecoderTransformer' -- confirm there.
    (ShiftRight bartEOSTokenId)
    (ShiftRight 0)
    -- Position builders (@mkPos@, then @mkDecoderPos@), both with offset 2.
    (MkAbsPosWithOffset 2)
    (MkAbsPosWithOffset 2)
    -- Padding mask builder keyed on the BART pad token id.
    (MkTransformerPaddingMask bartPadTokenId)
    -- Attention, cross-attention and decoder attention mask builders, all
    -- using the BART data type and the same mask bias.
    (MkTransformerAttentionMask bartDataType bartAttentionMaskBias)
    (MkTransformerCrossAttentionMask bartDataType bartAttentionMaskBias)
    (MkTransformerDecoderAttentionMask bartDataType bartAttentionMaskBias)

-- | Creates a BART input tensor from nested lists of token ids.
--
-- This is 'mkTransformerInput' partially applied to 'bartPadTokenId'. The
-- result is a dense 'Int64' tensor without gradient, of shape
-- @[batchDim, seqDim]@, on the given device. It is produced in any monad with
-- a 'MonadThrow' instance.
--
-- NOTE(review): presumably the pad token id fills rows shorter than @seqDim@
-- and failures are reported through 'MonadThrow' -- confirm against
-- 'mkTransformerInput'.
mkBARTInput ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  SDim batchDim ->
  SDim seqDim ->
  SDevice device ->
  [[Int]] ->
  m output
mkBARTInput = mkTransformerInput bartPadTokenId