{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Transformer.BERT.Common where
import Control.Monad.Catch (MonadThrow)
import Data.Kind (Type)
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.GEncoderOnly (GEncoderOnlyTransformerF, GSimplifiedEncoderOnlyTransformer (..), encoderOnlyTransformerSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (MkAbsPos (..), MkTransformerAttentionMask (..), MkTransformerPaddingMask (..), STransformerHead, STransformerStyle (SBERT), TransformerHead (..), TransformerStyle (BERT), mkTransformerInput)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (SGetDim, Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))
-- | Default element dtype ('Float') of the BERT models specified in this
-- module.
type BERTDType = 'Float

-- | Value-level singleton for 'BERTDType'.
bertDType :: SDType BERTDType
bertDType = sing @BERTDType
-- | Data type of the BERT models specified in this module: 'BERTDType'
-- lifted into 'DataType'.
type BERTDataType = 'DataType BERTDType

-- | Value-level singleton for 'BERTDataType'.
bertDataType :: SDataType BERTDataType
bertDataType = sing @BERTDataType
-- | Dropout rate of BERT models; 'bertModelSpec' passes it to
-- 'encoderOnlyTransformerSpec'.
bertDropoutP :: Double
bertDropoutP = 0.1
-- | Dimension of the BERT position encoding: unnamed (@"*"@) with size 512.
type BERTPosEncDim = 'Dim ('Name "*") ('Size 512)

-- | Value-level singleton for 'BERTPosEncDim'.
bertPosEncDim :: SDim BERTPosEncDim
bertPosEncDim = sing @BERTPosEncDim
-- | Epsilon that 'bertModelSpec' passes to 'encoderOnlyTransformerSpec'.
--
-- NOTE(review): presumably the layer-normalisation epsilon; confirm against
-- "Torch.GraduallyTyped.NN.Transformer.GEncoderOnly".
bertEps :: Double
bertEps = 1e-12
-- | Maximum number of positions supported by BERT models.
--
-- This equals the size of 'BERTPosEncDim' (512); keep the two in sync.
bertMaxPositionEmbeddings :: Int
bertMaxPositionEmbeddings = 512
-- | Token id used for padding.
--
-- Used by the padding mask that 'bertModelSpec' builds and by 'mkBERTInput'.
bertPadTokenId :: Int
bertPadTokenId = 0
-- | Bias given to the attention mask ('MkTransformerAttentionMask') in
-- 'bertModelSpec'.
--
-- NOTE(review): presumably added to masked-out attention logits; confirm in
-- "Torch.GraduallyTyped.NN.Transformer.Type".
bertAttentionMaskBias :: Double
bertAttentionMaskBias = -10000
-- | Type of a BERT model.
--
-- A 'GSimplifiedEncoderOnlyTransformer' built from a 'BERT'-style
-- 'GEncoderOnlyTransformerF'. The data type is fixed to 'BERTDataType' and
-- the position-encoding dimension to 'BERTPosEncDim'. The wrapper is
-- completed with 'MkAbsPos', 'MkTransformerPaddingMask', and a
-- 'MkTransformerAttentionMask' in 'BERTDataType'.
type family
  BERTModelF
    (transformerHead :: TransformerHead)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (typeVocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  BERTModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim typeVocabDim hasDropout =
    GSimplifiedEncoderOnlyTransformer
      (GEncoderOnlyTransformerF 'BERT transformerHead numLayers gradient device BERTDataType headDim headEmbedDim embedDim inputEmbedDim ffnDim BERTPosEncDim vocabDim typeVocabDim hasDropout)
      MkAbsPos
      MkTransformerPaddingMask
      (MkTransformerAttentionMask BERTDataType)
-- | Specification of a BERT model, of type 'BERTModelF'.
--
-- The dimensions not passed as arguments (@headDim@, @headEmbedDim@,
-- @embedDim@, @inputEmbedDim@, @ffnDim@, @vocabDim@, @typeVocabDim@) are read
-- from their 'SingI' instances via type application.
--
-- The BERT-specific choices are fixed:
--
-- * the data type is 'bertDataType';
-- * the position-encoding dimension is 'bertPosEncDim';
-- * the dropout rate and epsilon are 'bertDropoutP' and 'bertEps';
-- * absolute positions are used ('MkAbsPos');
-- * the padding mask uses 'bertPadTokenId';
-- * the attention mask uses 'bertDataType' and 'bertAttentionMaskBias'.
bertModelSpec ::
  forall transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim typeVocabDim hasDropout.
  ( SingI headDim,
    SingI headEmbedDim,
    SingI embedDim,
    SingI inputEmbedDim,
    SingI ffnDim,
    SingI vocabDim,
    SingI typeVocabDim
  ) =>
  -- | head of the transformer
  STransformerHead transformerHead ->
  -- | number of layers
  SNat numLayers ->
  -- | gradient setting of the parameters
  SGradient gradient ->
  -- | device of the parameters
  SDevice device ->
  -- | whether dropout is enabled
  SHasDropout hasDropout ->
  ModelSpec (BERTModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim typeVocabDim hasDropout)
bertModelSpec transformerHead numLayers gradient device hasDropout =
  GSimplifiedEncoderOnlyTransformer
    ( encoderOnlyTransformerSpec
        SBERT
        transformerHead
        numLayers
        gradient
        device
        bertDataType
        (sing @headDim)
        (sing @headEmbedDim)
        (sing @embedDim)
        (sing @inputEmbedDim)
        (sing @ffnDim)
        bertPosEncDim
        (sing @vocabDim)
        (sing @typeVocabDim)
        hasDropout
        bertDropoutP
        bertEps
    )
    MkAbsPos
    (MkTransformerPaddingMask bertPadTokenId)
    (MkTransformerAttentionMask bertDataType bertAttentionMaskBias)
-- | Build a BERT input tensor from a batch of token-id sequences.
--
-- This delegates to 'mkTransformerInput' with 'bertPadTokenId' as the padding
-- token id.
--
-- NOTE(review): the pad id is presumably used to fill sequences shorter than
-- @seqDim@; confirm in "Torch.GraduallyTyped.NN.Transformer.Type".
--
-- The result is a dense, gradient-free 'Int64' tensor of shape
-- @[batchDim, seqDim]@ on @device@. The 'MonadThrow' constraint lets it report
-- failures in @m@.
mkBERTInput ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  -- | batch dimension of the result
  SDim batchDim ->
  -- | sequence dimension of the result
  SDim seqDim ->
  -- | device of the result
  SDevice device ->
  -- | batch of token-id sequences
  [[Int]] ->
  m output
mkBERTInput = mkTransformerInput bertPadTokenId