{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Transformer.BART.Common where
import Control.Monad.Catch (MonadThrow)
import Data.Kind (Type)
import Data.Singletons (SingI (..))
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (ModelSpec)
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (GEncoderDecoderTransformerF, GSimplifiedEncoderDecoderTransformer (..), encoderDecoderTransformerSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (MkAbsPos (..), MkTransformerAttentionMask (..), MkTransformerCrossAttentionMask (..), MkTransformerDecoderAttentionMask (..), MkTransformerPaddingMask (..), STransformerHead, STransformerStyle (SBART), ShiftRight (..), TransformerHead (..), TransformerStyle (BART), mkTransformerInput)
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (SGetDim, Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))
-- | Type-level element type (@'Float@) from which 'BARTDataType' is built.
type BARTDType = 'Float

-- | Singleton value corresponding to 'BARTDType'.
bartDType :: SDType BARTDType
bartDType = sing @BARTDType
-- | Type-level data type of the BART model: 'BARTDType' wrapped in @'DataType@.
type BARTDataType = 'DataType BARTDType

-- | Singleton value corresponding to 'BARTDataType'.
--
-- 'bartModelSpec' passes it to 'encoderDecoderTransformerSpec' and to the
-- three attention-mask constructors.
bartDataType :: SDataType BARTDataType
bartDataType = sing @BARTDataType
-- | Dropout rate used for BART.
--
-- 'bartModelSpec' forwards this value to 'encoderDecoderTransformerSpec'.
bartDropoutP :: Double
bartDropoutP = 0.1
-- | Type-level positional-encoding dimension of BART: unnamed (@"*"@), size 1026.
--
-- NOTE(review): 1026 is presumably 'bartMaxPositionEmbeddings' (1024) plus the
-- position offset of 2 used in 'bartModelSpec' -- confirm.
type BARTPosEncDim = 'Dim ('Name "*") ('Size 1026)

-- | Singleton value corresponding to 'BARTPosEncDim'.
bartPosEncDim :: SDim BARTPosEncDim
bartPosEncDim = sing @BARTPosEncDim
-- | Epsilon used for BART.
--
-- 'bartModelSpec' forwards this value as the last argument of
-- 'encoderDecoderTransformerSpec'.
-- NOTE(review): presumably the layer-normalisation epsilon -- confirm.
bartEps :: Double
bartEps = 1e-5
-- | Maximum number of position embeddings for BART.
--
-- Not referenced by the other definitions visible in this module.
bartMaxPositionEmbeddings :: Int
bartMaxPositionEmbeddings = 1024
-- | Id of the padding token for BART.
--
-- Used by 'bartModelSpec' (argument of 'MkTransformerPaddingMask') and by
-- 'mkBARTInput' (argument of 'mkTransformerInput').
bartPadTokenId :: Int
bartPadTokenId = 1
-- | Id of the beginning-of-sequence token for BART.
--
-- Not referenced by the other definitions visible in this module.
bartBOSTokenId :: Int
bartBOSTokenId = 0
-- | Id of the end-of-sequence token for BART.
--
-- 'bartModelSpec' uses it as the fill value of its first 'ShiftRight'.
bartEOSTokenId :: Int
bartEOSTokenId = 2
-- | Attention mask bias for BART.
--
-- 'bartModelSpec' passes it, together with 'bartDataType', to
-- 'MkTransformerAttentionMask', 'MkTransformerCrossAttentionMask' and
-- 'MkTransformerDecoderAttentionMask'.
bartAttentionMaskBias :: Double
bartAttentionMaskBias = -10000
-- | Computes the type of a BART model from its type-level configuration.
--
-- The result is a 'GSimplifiedEncoderDecoderTransformer' whose model component
-- is 'GEncoderDecoderTransformerF' instantiated with:
--
-- * the @'BART@ transformer style,
-- * @numLayers@ for both layer-count arguments,
-- * 'BARTDataType' as the data type, and
-- * 'BARTPosEncDim' as the positional-encoding dimension.
--
-- The remaining components are 'MkAbsPos' (twice), 'MkTransformerPaddingMask'
-- and the three attention-mask builders, each indexed by 'BARTDataType'.
type family
  BARTModelF
    (transformerHead :: TransformerHead)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  BARTModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout =
    GSimplifiedEncoderDecoderTransformer
      (GEncoderDecoderTransformerF 'BART transformerHead numLayers numLayers gradient device BARTDataType headDim headEmbedDim embedDim inputEmbedDim ffnDim BARTPosEncDim vocabDim hasDropout)
      MkAbsPos
      MkAbsPos
      MkTransformerPaddingMask
      (MkTransformerAttentionMask BARTDataType)
      (MkTransformerCrossAttentionMask BARTDataType)
      (MkTransformerDecoderAttentionMask BARTDataType)
-- | Builds the model specification for a BART model.
--
-- The dimension singletons (@headDim@, @headEmbedDim@, @embedDim@,
-- @inputEmbedDim@, @ffnDim@, @vocabDim@) are not passed as values; they are
-- obtained from the 'SingI' constraints via 'sing', so callers pick them with
-- type applications or through the expected result type.
--
-- Arguments, in order: the transformer head, the number of layers (used for
-- both layer-count arguments of 'encoderDecoderTransformerSpec'), the
-- gradient requirement, the device, and whether dropout is enabled.
--
-- The BART-specific settings are fixed here: 'SBART' style, 'bartDataType',
-- 'bartPosEncDim', 'bartDropoutP', 'bartEps', 'bartEOSTokenId',
-- 'bartPadTokenId' and 'bartAttentionMaskBias'.
bartModelSpec ::
  forall transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout.
  ( SingI headDim,
    SingI headEmbedDim,
    SingI embedDim,
    SingI inputEmbedDim,
    SingI ffnDim,
    SingI vocabDim
  ) =>
  STransformerHead transformerHead ->
  SNat numLayers ->
  SGradient gradient ->
  SDevice device ->
  SHasDropout hasDropout ->
  ModelSpec (BARTModelF transformerHead numLayers gradient device headDim headEmbedDim embedDim inputEmbedDim ffnDim vocabDim hasDropout)
bartModelSpec transformerHead numLayers gradient device hasDropout =
  GSimplifiedEncoderDecoderTransformer
    -- Specification of the underlying encoder-decoder transformer.
    ( encoderDecoderTransformerSpec
        SBART
        transformerHead
        -- The same layer count is used for the encoder and the decoder.
        numLayers
        numLayers
        gradient
        device
        bartDataType
        (sing @headDim)
        (sing @headEmbedDim)
        (sing @embedDim)
        (sing @inputEmbedDim)
        (sing @ffnDim)
        bartPosEncDim
        (sing @vocabDim)
        hasDropout
        bartDropoutP
        bartEps
    )
    -- Two right-shifts: the first fills with the EOS token id, the second
    -- with 0.
    -- NOTE(review): presumably the decoder-input shift and the padding-mask
    -- shift, respectively -- confirm against GSimplifiedEncoderDecoderTransformer.
    (ShiftRight bartEOSTokenId)
    (ShiftRight 0)
    -- Absolute positions with an offset of 2, for both position builders.
    (MkAbsPosWithOffset 2)
    (MkAbsPosWithOffset 2)
    (MkTransformerPaddingMask bartPadTokenId)
    (MkTransformerAttentionMask bartDataType bartAttentionMaskBias)
    (MkTransformerCrossAttentionMask bartDataType bartAttentionMaskBias)
    (MkTransformerDecoderAttentionMask bartDataType bartAttentionMaskBias)
-- | Builds a BART input tensor from nested lists of token ids.
--
-- This is 'mkTransformerInput' partially applied to 'bartPadTokenId'. Per the
-- signature, the result is a dense @Int64@ tensor without gradient, of shape
-- @[batchDim, seqDim]@ on the given device, produced in any 'MonadThrow'
-- monad.
--
-- NOTE(review): padding and error behaviour come from 'mkTransformerInput'
-- and are not visible here -- presumably shorter sequences are filled with
-- the pad token id; confirm.
mkBARTInput ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  SDim batchDim ->
  SDim seqDim ->
  SDevice device ->
  [[Int]] ->
  m output
mkBARTInput = mkTransformerInput bartPadTokenId