{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Transformer.RoBERTa.Common where

import Control.Monad.Catch (MonadThrow)
import Data.Kind (Type)
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.GEncoderOnly (GEncoderOnlyTransformerF, GSimplifiedEncoderOnlyTransformer (..), encoderOnlyTransformerSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (MkAbsPos (..), MkTransformerAttentionMask (..), MkTransformerPaddingMask (..), STransformerHead, STransformerStyle (SRoBERTa), TransformerHead (..), TransformerStyle (RoBERTa), mkTransformerInput)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (SGetDim, Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))

-- | RoBERTa dType.
--
-- Type-level alias for the promoted @'Float@ constructor of 'DType'.
-- 'RoBERTaDataType' is built on top of this alias.
type RoBERTaDType = 'Float

-- | RoBERTa dType singleton.
--
-- Value-level witness of 'RoBERTaDType', obtained from its 'SingI' instance.
robertaDType :: SDType RoBERTaDType
robertaDType = sing @RoBERTaDType

-- | RoBERTa data type.
--
-- Wraps 'RoBERTaDType' in the promoted @'DataType@ constructor.
-- 'RoBERTaModelF' uses this alias both for the transformer itself and for
-- its attention mask.
type RoBERTaDataType = 'DataType RoBERTaDType

-- | RoBERTa data type singleton.
--
-- Value-level witness of 'RoBERTaDataType', obtained from its 'SingI'
-- instance. 'robertaModelSpec' passes it both to the transformer spec and to
-- the attention-mask constructor.
robertaDataType :: SDataType RoBERTaDataType
robertaDataType = sing @RoBERTaDataType

-- | RoBERTa dropout rate.
-- 'dropout_rate = 0.1'
--
-- Handed to 'encoderOnlyTransformerSpec' by 'robertaModelSpec'.
robertaDropoutP :: Double
robertaDropoutP = 0.1

-- | RoBERTa positional encoding dimension.
--
-- A dimension named @"*"@ of size 514, i.e. the same number as
-- 'robertaMaxPositionEmbeddings'.
--
-- Note the two extra positions: 'robertaModelSpec' builds absolute positions
-- with @MkAbsPosWithOffset 2@, so presumably 512 positions are usable and the
-- first two slots are reserved (TODO confirm against the position encoder).
type RoBERTaPosEncDim = 'Dim ('Name "*") ('Size 514)

-- | RoBERTa positional encoding dimension singleton.
--
-- Value-level witness of 'RoBERTaPosEncDim', obtained from its 'SingI'
-- instance.
robertaPosEncDim :: SDim RoBERTaPosEncDim
robertaPosEncDim = sing @RoBERTaPosEncDim

-- | RoBERTa layer-norm epsilon.
-- 'layer_norm_epsilon = 1e-5'
--
-- Handed to 'encoderOnlyTransformerSpec' by 'robertaModelSpec'.
robertaEps :: Double
robertaEps = 1e-5

-- | RoBERTa maximum number of position embeddings.
-- 'max_position_embeddings = 514'
--
-- Value-level counterpart of the size carried by 'RoBERTaPosEncDim'.
robertaMaxPositionEmbeddings :: Int
robertaMaxPositionEmbeddings = 514

-- | RoBERTa padding token id.
-- 'pad_token_id = 1'
--
-- Used by 'robertaModelSpec' to configure the padding mask and by
-- 'mkRoBERTaInput' when building input tensors.
robertaPadTokenId :: Int
robertaPadTokenId = 1

-- | RoBERTa begin-of-sentence token id.
-- 'bos_token_id = 0'
robertaBOSTokenId :: Int
robertaBOSTokenId = 0

-- | RoBERTa end-of-sentence token id.
-- 'eos_token_id = 2'
robertaEOSTokenId :: Int
robertaEOSTokenId = 2

-- | RoBERTa attention mask bias.
--
-- A large negative constant handed to 'MkTransformerAttentionMask' by
-- 'robertaModelSpec'; presumably it is the value applied to attention scores
-- at masked positions (TODO confirm against the mask implementation).
robertaAttentionMaskBias :: Double
robertaAttentionMaskBias = -10000

-- | Specifies the RoBERTa model.
--
-- Expands to a 'GSimplifiedEncoderOnlyTransformer' with four components:
--
-- - the encoder-only transformer computed by 'GEncoderOnlyTransformerF',
--   with the style fixed to @'RoBERTa@, the data type fixed to
--   'RoBERTaDataType', and the positional encoding dimension fixed to
--   'RoBERTaPosEncDim';
-- - 'MkAbsPos' for producing positions;
-- - 'MkTransformerPaddingMask' for producing the padding mask;
-- - 'MkTransformerAttentionMask' at 'RoBERTaDataType' for producing the
--   attention mask.
--
-- All remaining parameters (head, layer count, gradient, device, the various
-- dimensions, and the dropout flag) are forwarded unchanged to
-- 'GEncoderOnlyTransformerF'.
type family
  RoBERTaModelF
    (transformerHead :: TransformerHead)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (typeVocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  RoBERTaModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim typeVocabDim hasDropout =
    GSimplifiedEncoderOnlyTransformer
      (GEncoderOnlyTransformerF 'RoBERTa transformerHead numLayers gradient device RoBERTaDataType headDim headEmbedDim embedDim inputEmbedDim ffnDim RoBERTaPosEncDim vocabDim typeVocabDim hasDropout)
      MkAbsPos
      MkTransformerPaddingMask
      (MkTransformerAttentionMask RoBERTaDataType)

-- | Specifies the parameters of a RoBERTa model.
--
-- - @transformerHead@: the head of the RoBERTa model.
-- - @numLayers@: the number of layers in the RoBERTa model.
-- - @gradient@: whether to compute the gradient of the RoBERTa model.
-- - @device@: the computational device on which the RoBERTa model parameters are to be allocated.
-- - @hasDropout@: the dropout setting, forwarded unchanged to 'encoderOnlyTransformerSpec'.
--
-- The dimensions @headDim@, @headEmbedDim@, @embedDim@, @inputEmbedDim@,
-- @ffnDim@, @vocabDim@, and @typeVocabDim@ have no value-level argument.
-- They are recovered from their 'SingI' instances, so callers must fix them
-- through type application or through the expected result type.
--
-- The RoBERTa-specific constants are filled in here: style 'SRoBERTa',
-- 'robertaDataType', 'robertaPosEncDim', 'robertaDropoutP', 'robertaEps',
-- a position offset of 2, 'robertaPadTokenId' for the padding mask, and
-- 'robertaAttentionMaskBias' for the attention mask.
robertaModelSpec ::
  forall transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim typeVocabDim hasDropout.
  ( SingI headDim,
    SingI headEmbedDim,
    SingI embedDim,
    SingI inputEmbedDim,
    SingI ffnDim,
    SingI vocabDim,
    SingI typeVocabDim
  ) =>
  STransformerHead transformerHead ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SHasDropout hasDropout ->
  ModelSpec (RoBERTaModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim typeVocabDim hasDropout)
robertaModelSpec transformerHead numLayers gradient device hasDropout =
  GSimplifiedEncoderOnlyTransformer
    ( encoderOnlyTransformerSpec
        SRoBERTa
        transformerHead
        numLayers
        gradient
        device
        robertaDataType
        (sing @headDim)
        (sing @headEmbedDim)
        (sing @embedDim)
        (sing @inputEmbedDim)
        (sing @ffnDim)
        robertaPosEncDim
        (sing @vocabDim)
        (sing @typeVocabDim)
        hasDropout
        robertaDropoutP
        robertaEps
    )
    -- Positions are shifted by 2; see the note on 'RoBERTaPosEncDim' about
    -- the two extra positions.
    (MkAbsPosWithOffset 2)
    (MkTransformerPaddingMask robertaPadTokenId)
    (MkTransformerAttentionMask robertaDataType robertaAttentionMaskBias)

-- | Builds a RoBERTa input tensor from a batch of token id lists.
--
-- This is 'mkTransformerInput' partially applied to 'robertaPadTokenId'.
--
-- - the 'SDim' arguments give the batch and sequence dimensions of the result.
-- - the 'SDevice' argument gives the device of the result.
-- - the @[[Int]]@ argument holds the token ids, one inner list per batch entry.
--
-- The result is a dense 'Int64' tensor without gradient of shape
-- @[batchDim, seqDim]@, produced in any 'MonadThrow' monad.
--
-- NOTE(review): presumably the pad token id fills sequences shorter than
-- @seqDim@ and shape mismatches are reported through 'MonadThrow'. Confirm
-- against 'mkTransformerInput'.
mkRoBERTaInput ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  SDim batchDim ->
  SDim seqDim ->
  SDevice device ->
  [[Int]] ->
  m output
mkRoBERTaInput = mkTransformerInput robertaPadTokenId