{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
module Torch.GraduallyTyped.NN.Transformer.GEncoderOnly where
import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import Data.Singletons (SingKind (fromSing))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType)
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Sparse (Embedding (..), EmbeddingSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GLMHead (GLMHeadF, lmHeadSpec)
import Torch.GraduallyTyped.NN.Transformer.GTransformer (TransformerEncoderF, transformerEncoderSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead (..), STransformerStyle (..), TransformerHead (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked)
import Torch.GraduallyTyped.Prelude.Maybe (SMaybe (SNothing))
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient)
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, Size (..))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add, mulScalar)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Prelude hiding (head)
-- | Configuration flag selecting whether an encoder-only transformer applies
-- embedding scaling to its input embeddings.
--
-- It is a plain value with no tensors of its own. Its 'HasStateDict' instance
-- loads it back unchanged and writes nothing.
--
-- NOTE(review): the scaling itself happens in the forward pass, which is not
-- part of this chunk. Presumably the embeddings are multiplied by a factor
-- derived from 'eotInputEmbedDim'; confirm there.
data EncoderOnlyTransformerHasEmbedScaling
  = -- | Apply embedding scaling.
    EncoderOnlyTransformerWithEmbedScaling
  | -- | Do not apply embedding scaling (used by the BERT and RoBERTa specs).
    EncoderOnlyTransformerWithoutEmbedScaling
  deriving (Eq, Ord, Show, Generic)
-- | The embed-scaling flag is its own model spec: there is nothing to build
-- from it, so the spec and the runtime value coincide.
type instance ModelSpec EncoderOnlyTransformerHasEmbedScaling = EncoderOnlyTransformerHasEmbedScaling

-- | Initializing the flag returns the spec unchanged. The generator is handed
-- back untouched because nothing is sampled from it.
instance HasInitialize EncoderOnlyTransformerHasEmbedScaling generatorDevice EncoderOnlyTransformerHasEmbedScaling generatorDevice where
  initialize hasEmbedScaling g = pure (hasEmbedScaling, g)
-- | The embed-scaling flag stores nothing in a state dict.
--
-- Loading ignores the key and returns the spec as the value. Saving is a
-- no-op.
instance HasStateDict EncoderOnlyTransformerHasEmbedScaling where
  fromStateDict hasEmbedScaling _ = pure hasEmbedScaling
  toStateDict _ _ = pure ()
-- | Generic encoder-only transformer model (the shape used here for BERT and
-- RoBERTa).
--
-- Type parameters:
--
-- - @inputEmbedDim@: the input embedding dimension.
-- - @encoder@: the transformer encoder.
-- - @encoderEmbedding@: the token (word) embedding layer.
-- - @encoderTypeEmbedding@: the token-type embedding layer.
-- - @head@: the output head, or @()@ when the model has no head.
data
  GEncoderOnlyTransformer
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (encoder :: Type)
    (encoderEmbedding :: Type)
    (encoderTypeEmbedding :: Type)
    (head :: Type)
  where
  GEncoderOnlyTransformer ::
    forall inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head.
    { -- | Singleton for the input embedding dimension.
      eotInputEmbedDim :: SDim inputEmbedDim,
      -- | Transformer encoder.
      eotEncoder :: encoder,
      -- | Token embedding layer.
      eotEmbedding :: encoderEmbedding,
      -- | Token-type embedding layer.
      eotTypeEmbedding :: encoderTypeEmbedding,
      -- | Output head.
      eotHead :: head,
      -- | Whether embedding scaling is applied.
      eotEmbedScaling :: EncoderOnlyTransformerHasEmbedScaling
    } ->
    GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
  deriving stock (Show, Generic)

-- | The spec of a 'GEncoderOnlyTransformer' replaces every sub-model type by its
-- 'ModelSpec'. The embedding-dimension index is kept unchanged.
type instance
  ModelSpec (GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head) =
    GEncoderOnlyTransformer inputEmbedDim (ModelSpec encoder) (ModelSpec encoderEmbedding) (ModelSpec encoderTypeEmbedding) (ModelSpec head)
-- | Concrete 'GEncoderOnlyTransformer' type for a given transformer @style@,
-- head kind and hyper-parameters.
--
-- It is assembled from:
--
-- - a named 'TransformerEncoderF' encoder,
-- - 'EOTEmbeddingF' (token embedding over @vocabDim@),
-- - 'EOTTypeEmbeddingF' (token-type embedding over @typeVocabDim@),
-- - 'EOTHeadF' (the head selected by @transformerHead@).
type family
  GEncoderOnlyTransformerF
    (style :: TransformerStyle)
    (transformerHead :: TransformerHead)
    (numLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (typeVocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  GEncoderOnlyTransformerF style transformerHead numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim typeVocabDim hasDropout =
    GEncoderOnlyTransformer
      inputEmbedDim
      (NamedModel (TransformerEncoderF style numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout))
      (EOTEmbeddingF style gradient device dataType inputEmbedDim vocabDim)
      (EOTTypeEmbeddingF style gradient device dataType inputEmbedDim typeVocabDim)
      (EOTHeadF style transformerHead gradient device dataType inputEmbedDim vocabDim)
-- | Token embedding layer of an encoder-only transformer.
--
-- It is a named dense-layout 'Embedding' over @vocabDim@ producing
-- @inputEmbedDim@-sized vectors (judging by the argument order of
-- 'Embedding'; confirm). The @style@ argument is ignored: every style yields
-- the same layer type.
type family
  EOTEmbeddingF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  EOTEmbeddingF _ gradient device dataType inputEmbedDim vocabDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType vocabDim inputEmbedDim 'Nothing)
-- | Token-type embedding layer of an encoder-only transformer.
--
-- Only 'BERT and 'RoBERTa have equations, so the family does not reduce for
-- any other style. RoBERTa reuses BERT's layer type: a named dense-layout
-- 'Embedding' over @typeVocabDim@ producing @inputEmbedDim@-sized vectors.
type family
  EOTTypeEmbeddingF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (typeVocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  EOTTypeEmbeddingF 'BERT gradient device dataType inputEmbedDim typeVocabDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType typeVocabDim inputEmbedDim 'Nothing)
  -- RoBERTa's token-type embedding has the same shape as BERT's.
  EOTTypeEmbeddingF 'RoBERTa gradient device dataType inputEmbedDim typeVocabDim =
    EOTTypeEmbeddingF 'BERT gradient device dataType inputEmbedDim typeVocabDim
-- | Output head of an encoder-only transformer, selected by @transformerHead@.
--
-- - 'WithoutHead: no head. The slot is filled with @()@ whatever the other
--   arguments are.
-- - 'WithLMHead: a named language-modelling head 'GLMHeadF' for the given
--   @style@, mapping @inputEmbedDim@ to @vocabDim@.
type family
  EOTHeadF
    (style :: TransformerStyle)
    (transformerHead :: TransformerHead)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  EOTHeadF _ 'WithoutHead _ _ _ _ _ =
    ()
  EOTHeadF style 'WithLMHead gradient device dataType inputEmbedDim vocabDim =
    NamedModel (GLMHeadF style gradient device dataType inputEmbedDim vocabDim)
encoderOnlyTransformerSpec ::
forall style transformerHead numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim typeVocabDim hasDropout.
STransformerStyle style ->
STransformerHead transformerHead ->
SNat numLayers ->
SGradient gradient ->
SDevice device ->
SDataType dataType ->
SDim headDim ->
SDim headEmbedDim ->
SDim embedDim ->
SDim inputEmbedDim ->
SDim ffnDim ->
SDim posEncDim ->
SDim vocabDim ->
SDim typeVocabDim ->
SHasDropout hasDropout ->
Double ->
Double ->
ModelSpec (GEncoderOnlyTransformerF style transformerHead numLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim typeVocabDim hasDropout)
encoderOnlyTransformerSpec :: forall (style :: TransformerStyle)
(transformerHead :: TransformerHead) (numLayers :: Nat)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat))
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(ffnDim :: Dim (Name Symbol) (Size Nat))
(posEncDim :: Dim (Name Symbol) (Size Nat))
(vocabDim :: Dim (Name Symbol) (Size Nat))
(typeVocabDim :: Dim (Name Symbol) (Size Nat))
(hasDropout :: HasDropout).
STransformerStyle style
-> STransformerHead transformerHead
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SDim vocabDim
-> SDim typeVocabDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
(GEncoderOnlyTransformerF
style
transformerHead
numLayers
gradient
device
dataType
headDim
headEmbedDim
embedDim
inputEmbedDim
ffnDim
posEncDim
vocabDim
typeVocabDim
hasDropout)
encoderOnlyTransformerSpec STransformerStyle style
style STransformerHead transformerHead
transformerHead SNat numLayers
numLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SDim vocabDim
vocabDim SDim typeVocabDim
typeVocabDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps =
let encoderSpec :: STransformerStyle style
-> NamedModel
(GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout)))
encoderSpec STransformerStyle style
ST5 = forall a. HasCallStack => a
undefined
encoderSpec STransformerStyle style
SByT5 = forall a. HasCallStack => a
undefined
encoderSpec STransformerStyle style
SBART = forall a. HasCallStack => a
undefined
encoderSpec STransformerStyle style
SMBART = forall a. HasCallStack => a
undefined
encoderSpec STransformerStyle style
SPegasus = forall a. HasCallStack => a
undefined
encoderSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"bert." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'BERT
SBERT
encoderSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"roberta." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'RoBERTa
SRoBERTa
encoderSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
embeddingSpec :: STransformerStyle style
-> NamedModel
(EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing)
embeddingSpec STransformerStyle style
ST5 = forall a. HasCallStack => a
undefined
embeddingSpec STransformerStyle style
SByT5 = forall a. HasCallStack => a
undefined
embeddingSpec STransformerStyle style
SBART = forall a. HasCallStack => a
undefined
embeddingSpec STransformerStyle style
SMBART = forall a. HasCallStack => a
undefined
embeddingSpec STransformerStyle style
SPegasus = forall a. HasCallStack => a
undefined
embeddingSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"bert.embeddings.word_embeddings." EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing
embeddingSpec'
embeddingSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"roberta.embeddings.word_embeddings." EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing
embeddingSpec'
embeddingSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
typeEmbeddingSpec :: STransformerStyle style
-> ModelSpec
(EOTTypeEmbeddingF
style gradient device dataType inputEmbedDim typeVocabDim)
typeEmbeddingSpec STransformerStyle style
ST5 = forall a. HasCallStack => a
undefined
typeEmbeddingSpec STransformerStyle style
SByT5 = forall a. HasCallStack => a
undefined
typeEmbeddingSpec STransformerStyle style
SBART = forall a. HasCallStack => a
undefined
typeEmbeddingSpec STransformerStyle style
SMBART = forall a. HasCallStack => a
undefined
typeEmbeddingSpec STransformerStyle style
SPegasus = forall a. HasCallStack => a
undefined
typeEmbeddingSpec STransformerStyle style
SBERT = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"bert.embeddings.token_type_embeddings." EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
typeVocabDim
inputEmbedDim
'Nothing
typeEmbeddingSpec'
typeEmbeddingSpec STransformerStyle style
SRoBERTa = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"roberta.embeddings.token_type_embeddings." EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
typeVocabDim
inputEmbedDim
'Nothing
typeEmbeddingSpec'
typeEmbeddingSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
headSpec :: STransformerStyle style
-> STransformerHead transformerHead
-> ModelSpec
(EOTHeadF
style
transformerHead
gradient
device
dataType
inputEmbedDim
vocabDim)
headSpec STransformerStyle style
ST5 STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
headSpec STransformerStyle style
SByT5 STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
headSpec STransformerStyle style
SBART STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
headSpec STransformerStyle style
SMBART STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
headSpec STransformerStyle style
SPegasus STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
headSpec STransformerStyle style
SBERT STransformerHead transformerHead
SWithoutHead = ()
headSpec STransformerStyle style
SBERT STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"cls.predictions." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
inputEmbedDim
(ModelSpec
(LMHeadDenseF style gradient device dataType inputEmbedDim))
(ModelSpec (LMHeadActivationF style))
(ModelSpec
(LMHeadLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec
(LMHeadDecoderF
style gradient device dataType inputEmbedDim vocabDim))
(ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'BERT
SBERT
headSpec STransformerStyle style
SRoBERTa STransformerHead transformerHead
SWithoutHead = ()
headSpec STransformerStyle style
SRoBERTa STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"lm_head." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
inputEmbedDim
(ModelSpec
(LMHeadDenseF style gradient device dataType inputEmbedDim))
(ModelSpec (LMHeadActivationF style))
(ModelSpec
(LMHeadLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec
(LMHeadDecoderF
style gradient device dataType inputEmbedDim vocabDim))
(ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'RoBERTa
SRoBERTa
headSpec STransformerStyle style
SGPT2 STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
embedScalingSpec :: STransformerStyle style -> EncoderOnlyTransformerHasEmbedScaling
embedScalingSpec :: STransformerStyle style -> EncoderOnlyTransformerHasEmbedScaling
embedScalingSpec STransformerStyle style
ST5 = forall a. HasCallStack => a
undefined
embedScalingSpec STransformerStyle style
SByT5 = forall a. HasCallStack => a
undefined
embedScalingSpec STransformerStyle style
SBART = forall a. HasCallStack => a
undefined
embedScalingSpec STransformerStyle style
SMBART = forall a. HasCallStack => a
undefined
embedScalingSpec STransformerStyle style
SPegasus = forall a. HasCallStack => a
undefined
embedScalingSpec STransformerStyle style
SBERT = EncoderOnlyTransformerHasEmbedScaling
EncoderOnlyTransformerWithoutEmbedScaling
embedScalingSpec STransformerStyle style
SRoBERTa = EncoderOnlyTransformerHasEmbedScaling
EncoderOnlyTransformerWithoutEmbedScaling
embedScalingSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
in forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
encoderEmbedding encoderTypeEmbedding head.
SDim inputEmbedDim
-> encoder
-> encoderEmbedding
-> encoderTypeEmbedding
-> head
-> EncoderOnlyTransformerHasEmbedScaling
-> GEncoderOnlyTransformer
inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head
GEncoderOnlyTransformer
SDim inputEmbedDim
inputEmbedDim
(STransformerStyle style
-> NamedModel
(GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout)))
encoderSpec STransformerStyle style
style)
(STransformerStyle style
-> NamedModel
(EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing)
embeddingSpec STransformerStyle style
style)
(STransformerStyle style
-> ModelSpec
(EOTTypeEmbeddingF
style gradient device dataType inputEmbedDim typeVocabDim)
typeEmbeddingSpec STransformerStyle style
style)
(STransformerStyle style
-> STransformerHead transformerHead
-> ModelSpec
(EOTHeadF
style
transformerHead
gradient
device
dataType
inputEmbedDim
vocabDim)
headSpec STransformerStyle style
style STransformerHead transformerHead
transformerHead)
(STransformerStyle style -> EncoderOnlyTransformerHasEmbedScaling
embedScalingSpec STransformerStyle style
style)
where
encoderSpec' :: _
encoderSpec' :: STransformerStyle style
-> GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle style
style' = forall (style :: TransformerStyle) (numLayers :: Nat)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat))
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(ffnDim :: Dim (Name Symbol) (Size Nat))
(posEncDim :: Dim (Name Symbol) (Size Nat))
(hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
(TransformerEncoderF
style
numLayers
gradient
device
dataType
headDim
headEmbedDim
embedDim
inputEmbedDim
ffnDim
posEncDim
hasDropout)
transformerEncoderSpec STransformerStyle style
style' SNat numLayers
numLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps
embeddingSpec' :: EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing
embeddingSpec' = forall (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(embedNumDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat))
(paddingIdx :: Maybe Nat).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim vocabDim
vocabDim SDim inputEmbedDim
inputEmbedDim forall a. SMaybe 'Nothing
SNothing
typeEmbeddingSpec' :: EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
typeVocabDim
inputEmbedDim
'Nothing
typeEmbeddingSpec' = forall (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(embedNumDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat))
(paddingIdx :: Maybe Nat).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim typeVocabDim
typeVocabDim SDim inputEmbedDim
inputEmbedDim forall a. SMaybe 'Nothing
SNothing
headSpec' :: _
headSpec' :: STransformerStyle style
-> GLMHead
inputEmbedDim
(ModelSpec
(LMHeadDenseF style gradient device dataType inputEmbedDim))
(ModelSpec (LMHeadActivationF style))
(ModelSpec
(LMHeadLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec
(LMHeadDecoderF
style gradient device dataType inputEmbedDim vocabDim))
(ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle style
style' = forall (style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(vocabDim :: Dim (Name Symbol) (Size Nat)).
STransformerStyle style
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputEmbedDim
-> SDim vocabDim
-> Double
-> ModelSpec
(GLMHeadF style gradient device dataType inputEmbedDim vocabDim)
lmHeadSpec STransformerStyle style
style' SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim inputEmbedDim
inputEmbedDim SDim vocabDim
vocabDim Double
eps
-- | Initialization of a 'GEncoderOnlyTransformer' from its spec.
--
-- The constraints thread the generator device through the four components
-- in the order encoder, embedding, type embedding, head:
-- @generatorDevice -> generatorDevice0 -> generatorDevice1 ->
-- generatorDevice2 -> generatorOutputDevice@.
-- The @inputEmbedDim@ type parameter is the same on the input and output
-- side; only the four component types change (to their primed versions).
--
-- NOTE(review): no method definitions appear in this declaration, so the
-- class method(s) must come from 'HasInitialize' defaults — confirm the class
-- provides them, otherwise GHC reports a missing method.
instance
  ( HasInitialize encoder generatorDevice encoder' generatorDevice0,
    HasInitialize encoderEmbedding generatorDevice0 encoderEmbedding' generatorDevice1,
    HasInitialize encoderTypeEmbedding generatorDevice1 encoderTypeEmbedding' generatorDevice2,
    HasInitialize head generatorDevice2 head' generatorOutputDevice
  ) =>
  HasInitialize
    (GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head)
    generatorDevice
    (GEncoderOnlyTransformer inputEmbedDim encoder' encoderEmbedding' encoderTypeEmbedding' head')
    generatorOutputDevice
-- | State-dict support for 'GEncoderOnlyTransformer', available whenever
-- each of its four components (encoder, embedding, type embedding, head)
-- has a 'HasStateDict' instance.
--
-- NOTE(review): no method definitions appear in this declaration; confirm
-- the 'HasStateDict' class defaults cover its methods.
instance
  (HasStateDict encoder, HasStateDict encoderEmbedding, HasStateDict encoderTypeEmbedding, HasStateDict head) =>
  HasStateDict (GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head)
-- | A wrapper that pairs a model with three auxiliary components.
--
-- All four components are left fully polymorphic. Going by the field names,
-- the auxiliaries presumably build positions, a padding mask and an attention
-- mask for the wrapped model's input — NOTE(review): confirm against the
-- forward instance for this type elsewhere in the module.
data
  GSimplifiedEncoderOnlyTransformer
    (model :: Type)
    (mkPos :: Type)
    (mkPaddingMask :: Type)
    (mkAttentionMask :: Type)
  where
  GSimplifiedEncoderOnlyTransformer ::
    forall model mkPos mkPaddingMask mkAttentionMask.
    { -- | the wrapped model
      seotModel :: model,
      -- | component stored as @mkPos@
      seotMkPos :: mkPos,
      -- | component stored as @mkPaddingMask@
      seotMkPaddingMask :: mkPaddingMask,
      -- | component stored as @mkAttentionMask@
      seotMkAttentionMask :: mkAttentionMask
    } ->
    GSimplifiedEncoderOnlyTransformer model mkPos mkPaddingMask mkAttentionMask
  deriving stock (Eq, Ord, Show, Generic)
-- The spec of a 'GSimplifiedEncoderOnlyTransformer' is the same wrapper
-- applied component-wise to the specs of its four parts.
type instance
  ModelSpec
    (GSimplifiedEncoderOnlyTransformer m pos padMask attnMask) =
    GSimplifiedEncoderOnlyTransformer
      (ModelSpec m)
      (ModelSpec pos)
      (ModelSpec padMask)
      (ModelSpec attnMask)
-- | Initialization of a 'GSimplifiedEncoderOnlyTransformer' from its spec.
--
-- The constraints thread the generator device through the four components
-- in the order model, mkPos, mkPaddingMask, mkAttentionMask:
-- @generatorDevice -> generatorDevice0 -> generatorDevice1 ->
-- generatorDevice2 -> generatorOutputDevice@.
--
-- NOTE(review): no method definitions appear in this declaration, so the
-- class method(s) must come from 'HasInitialize' defaults — confirm the class
-- provides them.
instance
  ( HasInitialize model generatorDevice model' generatorDevice0,
    HasInitialize mkPos generatorDevice0 mkPos' generatorDevice1,
    HasInitialize mkPaddingMask generatorDevice1 mkPaddingMask' generatorDevice2,
    HasInitialize mkAttentionMask generatorDevice2 mkAttentionMask' generatorOutputDevice
  ) =>
  HasInitialize
    (GSimplifiedEncoderOnlyTransformer model mkPos mkPaddingMask mkAttentionMask)
    generatorDevice
    (GSimplifiedEncoderOnlyTransformer model' mkPos' mkPaddingMask' mkAttentionMask')
    generatorOutputDevice
-- | State-dict support for 'GSimplifiedEncoderOnlyTransformer', available
-- whenever each of its four components has a 'HasStateDict' instance.
--
-- NOTE(review): no method definitions appear in this declaration; confirm
-- the 'HasStateDict' class defaults cover its methods.
instance
  (HasStateDict model, HasStateDict mkPos, HasStateDict mkPaddingMask, HasStateDict mkAttentionMask) =>
  HasStateDict (GSimplifiedEncoderOnlyTransformer model mkPos mkPaddingMask mkAttentionMask)
-- | Input record for an encoder-only transformer.
--
-- All four components are left polymorphic. By their field names they hold
-- the input itself, its accompanying input types, positions and an attention
-- mask. NOTE(review): the concrete types/shapes are fixed by the forward
-- instances elsewhere in the module — confirm there.
data EncoderOnlyTransformerInput input inputType pos attentionMask where
  EncoderOnlyTransformerInput ::
    forall input inputType pos attentionMask.
    { -- | the input
      eotInput :: input,
      -- | the input types accompanying 'eotInput'
      eotInputType :: inputType,
      -- | positions
      eotPos :: pos,
      -- | attention mask
      eotAttentionMask :: attentionMask
    } ->
    EncoderOnlyTransformerInput input inputType pos attentionMask
  deriving stock (Eq, Ord, Show, Generic)
-- | Reduced input record for an encoder-only transformer: only the input and
-- its accompanying input types. Unlike 'EncoderOnlyTransformerInput' it
-- carries no positions or attention mask. NOTE(review): presumably these are
-- derived by the simplified wrapper's auxiliary components — confirm.
data SimplifiedEncoderOnlyTransformerInput input inputType where
  SimplifiedEncoderOnlyTransformerInput ::
    forall input inputType.
    { -- | the input
      seotInput :: input,
      -- | the input types accompanying 'seotInput'
      seotInputType :: inputType
    } ->
    SimplifiedEncoderOnlyTransformerInput input inputType
  deriving stock (Eq, Ord, Show, Generic)
-- | Output of a 'GEncoderOnlyTransformer' forward pass.
--
-- Wraps the value produced by the model's head.
data EncoderOnlyTransformerOutput output where
  EncoderOnlyTransformerOutput ::
    forall output.
    { -- | Result produced by the transformer head.
      eotOutput :: output
    } ->
    EncoderOnlyTransformerOutput output
  deriving stock (Eq, Ord, Show, Generic)
-- | Output of a 'GSimplifiedEncoderOnlyTransformer' forward pass.
--
-- Contains the wrapped model's output together with the padding mask that
-- was computed from the input.
data SimplifiedEncoderOnlyTransformerOutput output paddingMask where
  SimplifiedEncoderOnlyTransformerOutput ::
    forall output paddingMask.
    { -- | Output of the wrapped encoder-only model.
      seotOutput :: output,
      -- | Padding mask derived from the input.
      --
      -- NOTE(review): the @sedt@ prefix does not match the @seot@ prefix
      -- of the sibling field; kept as-is because renaming it would break
      -- existing callers.
      sedtPaddingMask :: paddingMask
    } ->
    SimplifiedEncoderOnlyTransformerOutput output paddingMask
  deriving stock (Eq, Ord, Show, Generic)
-- | Forward pass of the generic encoder-only transformer.
--
-- The stages, run in this order while threading the random generator, are:
--
--   1. embed 'eotInput' with 'eotEmbedding' and, if 'eotEmbedScaling' is
--      'EncoderOnlyTransformerWithEmbedScaling', multiply the result by
--      @sqrt@ of the input embedding dimension's size;
--   2. embed 'eotInputType' with 'eotTypeEmbedding', scaled the same way;
--   3. add the two embeddings (shapes are broadcast, see the
--      'BroadcastShapesF' constraint);
--   4. run 'eotEncoder' on the sum together with 'eotPos' and
--      'eotAttentionMask';
--   5. run 'eotHead' on the encoder output and wrap the result in
--      'EncoderOnlyTransformerOutput'.
--
-- The generator device evolves along
-- @generatorDevice -> embeddingGeneratorOutputDevice ->
-- typeEmbeddingGeneratorOutputDevice -> encoderGeneratorOutputDevice ->
-- generatorOutputDevice@, as fixed by the instance context.
instance
  ( HasForward
      encoderEmbedding
      input
      generatorDevice
      embeddingOutput
      embeddingGeneratorOutputDevice,
    embeddingOutput ~ Tensor gradient' layout' device' dataType' shape',
    HasForward
      encoderTypeEmbedding
      inputType
      embeddingGeneratorOutputDevice
      typeEmbeddingOutput
      typeEmbeddingGeneratorOutputDevice,
    typeEmbeddingOutput ~ Tensor gradient'' layout'' device'' dataType'' shape'',
    HasForward
      encoder
      ( Tensor
          (gradient' <|> gradient'')
          (layout' <+> layout'')
          (device' <+> device'')
          (dataType' <+> dataType'')
          (BroadcastShapesF shape' shape''),
        pos,
        attentionMask
      )
      typeEmbeddingGeneratorOutputDevice
      encoderOutput
      encoderGeneratorOutputDevice,
    Catch (BroadcastShapesF shape' shape''),
    HasForward
      head
      encoderOutput
      encoderGeneratorOutputDevice
      headOutput
      generatorOutputDevice
  ) =>
  HasForward
    (GEncoderOnlyTransformer inputEmbedDim encoder encoderEmbedding encoderTypeEmbedding head)
    (EncoderOnlyTransformerInput input inputType pos attentionMask)
    generatorDevice
    (EncoderOnlyTransformerOutput headOutput)
    generatorOutputDevice
  where
    forward GEncoderOnlyTransformer {..} EncoderOnlyTransformerInput {..} =
      let -- Scale factor: square root of the input embedding dimension's
          -- size. Only used when embedding scaling is enabled.
          scaling :: Double
          scaling =
            sqrt . fromIntegral . forgetIsChecked . dimSize . fromSing $ eotInputEmbedDim
          -- Token embedding, optionally scaled.
          embeddedInput =
            ireturn eotInput
              >>>= IxStateT . forward eotEmbedding
              >>>= ilift
                . ( \case
                      EncoderOnlyTransformerWithoutEmbedScaling -> pure
                      EncoderOnlyTransformerWithEmbedScaling -> flip mulScalar scaling
                  )
                  eotEmbedScaling
          -- Token-type embedding, optionally scaled by the same factor.
          -- NOTE(review): the type embedding is scaled exactly like the
          -- token embedding; confirm this matches the reference model.
          embeddedInputType =
            ireturn eotInputType
              >>>= IxStateT . forward eotTypeEmbedding
              >>>= ilift
                . ( \case
                      EncoderOnlyTransformerWithoutEmbedScaling -> pure
                      EncoderOnlyTransformerWithEmbedScaling -> flip mulScalar scaling
                  )
                  eotEmbedScaling
       in runIxStateT $
            (,) <<$>> embeddedInput <<*>> embeddedInputType
              -- sum of token and token-type embeddings (broadcasting add)
              >>>= ilift . uncurry add
              >>>= (\input' -> IxStateT $ forward eotEncoder (input', eotPos, eotAttentionMask))
              >>>= IxStateT . forward eotHead
              >>>= ireturn . EncoderOnlyTransformerOutput
-- | Forward pass of the simplified encoder-only transformer.
--
-- Derives the auxiliary inputs of the wrapped model from the raw input and
-- then runs it. The steps, in order, are:
--
--   1. compute a padding mask from 'seotInput' with 'seotMkPaddingMask';
--   2. compute positions from 'seotInput' with 'seotMkPos';
--   3. compute an attention mask from that padding mask with
--      'seotMkAttentionMask';
--   4. run 'seotModel' on
--      @EncoderOnlyTransformerInput seotInput seotInputType pos attentionMask@.
--
-- The model's output is returned together with the padding mask from step 1.
-- Steps 1-3 leave the generator device unchanged (see the instance
-- context); only the model step moves it to @generatorOutputDevice@.
instance
  ( HasForward
      mkPaddingMask
      input
      generatorDevice
      paddingMask
      generatorDevice,
    HasForward
      mkAttentionMask
      paddingMask
      generatorDevice
      attentionMask
      generatorDevice,
    HasForward
      mkPos
      input
      generatorDevice
      pos
      generatorDevice,
    HasForward
      model
      (EncoderOnlyTransformerInput input inputType pos attentionMask)
      generatorDevice
      (EncoderOnlyTransformerOutput output)
      generatorOutputDevice
  ) =>
  HasForward
    (GSimplifiedEncoderOnlyTransformer model mkPos mkPaddingMask mkAttentionMask)
    (SimplifiedEncoderOnlyTransformerInput input inputType)
    generatorDevice
    (SimplifiedEncoderOnlyTransformerOutput output paddingMask)
    generatorOutputDevice
  where
    forward GSimplifiedEncoderOnlyTransformer {..} SimplifiedEncoderOnlyTransformerInput {..} =
      runIxStateT $
        ( let -- padding mask and positions are both derived from the raw input
              paddingMaskAction = IxStateT . forward seotMkPaddingMask $ seotInput
              posAction = IxStateT . forward seotMkPos $ seotInput
           in (,) <<$>> paddingMaskAction <<*>> posAction
        )
          >>>= ( \(paddingMask, pos) ->
                   let -- the attention mask is derived from the padding mask
                       attentionMaskAction = IxStateT . forward seotMkAttentionMask $ paddingMask
                    in ( EncoderOnlyTransformerInput seotInput seotInputType pos
                           <<$>> attentionMaskAction
                       )
                         >>>= IxStateT . forward seotModel
                         >>>= ( \(EncoderOnlyTransformerOutput output) ->
                                  -- hand the padding mask back alongside the model output
                                  ireturn $ SimplifiedEncoderOnlyTransformerOutput output paddingMask
                              )
               )