{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
module Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder where
import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import Data.Singletons (SingKind (fromSing))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Index.Type (Index (NegativeIndex), SIndex (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Functional.NonLinearActivation (SoftmaxF, logSoftmax)
import Torch.GraduallyTyped.NN.Sparse (Embedding (..), EmbeddingSpec (..))
import Torch.GraduallyTyped.NN.Transformer.GLMHead (GLMHeadF, lmHeadSpec)
import Torch.GraduallyTyped.NN.Transformer.GTransformer (TransformerDecoderF, TransformerEncoderF, transformerDecoderSpec, transformerEncoderSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead (..), STransformerStyle (..), ShiftRight, TransformerHead (WithLMHead, WithoutHead), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.Prelude.Maybe (SMaybe (SNothing))
import Torch.GraduallyTyped.Prelude.TypeLits (SNat)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim, SSelectDim (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Indexing (IndexDims, IndexType (..), Indices (..), SIndexType (..), SIndices (..), (!))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (GatherDimF, SqueezeDimF, UnsqueezeF, sGatherDim, sSqueezeDim, sUnsqueeze)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mulScalar)
import Torch.GraduallyTyped.Tensor.MathOperations.Reduction (MeanAllCheckF, meanAll)
import Torch.GraduallyTyped.Tensor.Type (Tensor ())
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Prelude hiding (head)
-- | Flag selecting whether an encoder-decoder transformer applies embedding
-- scaling.
--
-- NOTE(review): the exact scaling applied (presumably by a function of
-- 'edtInputEmbedDim') is implemented in the forward pass, which is not
-- visible here — confirm there.
data EncoderDecoderTransformerHasEmbedScaling
  = -- | embedding scaling is applied
    EncoderDecoderTransformerWithEmbedScaling
  | -- | embedding scaling is not applied
    EncoderDecoderTransformerWithoutEmbedScaling
  deriving stock (Eq, Ord, Show, Generic)
-- The embed-scaling flag has no parameters, so it serves as its own
-- model specification.
type instance
  ModelSpec EncoderDecoderTransformerHasEmbedScaling =
    EncoderDecoderTransformerHasEmbedScaling
-- | Initialising the embed-scaling flag allocates nothing: the spec is
-- returned as the model and the generator is passed through untouched.
instance HasInitialize EncoderDecoderTransformerHasEmbedScaling generatorDevice EncoderDecoderTransformerHasEmbedScaling generatorDevice where
  initialize hasEmbedScaling g = pure (hasEmbedScaling, g)
-- | The embed-scaling flag is not persisted in the state dictionary.
-- Loading ignores the key and returns the spec unchanged; saving writes
-- nothing.
instance HasStateDict EncoderDecoderTransformerHasEmbedScaling where
  fromStateDict hasEmbedScaling _ = pure hasEmbedScaling
  toStateDict _ _ = pure ()
-- | Generic encoder-decoder transformer.
--
-- The type parameters record the input embedding dimension and the types
-- of the sub-models. 'GEncoderDecoderTransformerF' shows how they are
-- instantiated for a concrete transformer style and head.
data
  GEncoderDecoderTransformer
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (encoder :: Type)
    (decoder :: Type)
    (sharedEmbedding :: Type)
    (head :: Type)
  where
  GEncoderDecoderTransformer ::
    forall inputEmbedDim encoder decoder sharedEmbedding head.
    { -- | input embedding dimension (singleton)
      edtInputEmbedDim :: SDim inputEmbedDim,
      -- | encoder sub-model
      edtEncoder :: encoder,
      -- | decoder sub-model
      edtDecoder :: decoder,
      -- | embedding sub-model (per 'EDTSharedEmbeddingF', a single embedding
      -- presumably shared by encoder and decoder — confirm in the forward pass)
      edtSharedEmbedding :: sharedEmbedding,
      -- | transformer head ('()' when there is no head, per 'EDTHeadF')
      edtHead :: head,
      -- | whether embedding scaling is applied
      edtEmbedScaling :: EncoderDecoderTransformerHasEmbedScaling
    } ->
    GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head
  deriving stock (Show, Generic)
-- The specification of an encoder-decoder transformer keeps the input
-- embedding dimension as-is. Each sub-model is replaced by its own
-- specification.
type instance
  ModelSpec
    (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head) =
    GEncoderDecoderTransformer
      inputEmbedDim
      (ModelSpec encoder)
      (ModelSpec decoder)
      (ModelSpec sharedEmbedding)
      (ModelSpec head)
-- | Concrete 'GEncoderDecoderTransformer' type for a given transformer style
-- and head.
--
-- The encoder, decoder, shared-embedding and head types are computed by
-- 'EDTEncoderF', 'EDTDecoderF', 'EDTSharedEmbeddingF' and 'EDTHeadF'
-- respectively.
type family
  GEncoderDecoderTransformerF
    (style :: TransformerStyle)
    (transformerHead :: TransformerHead)
    (numEncoderLayers :: Nat)
    (numDecoderLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  GEncoderDecoderTransformerF
    style
    transformerHead
    numEncoderLayers
    numDecoderLayers
    gradient
    device
    dataType
    headDim
    headEmbedDim
    embedDim
    inputEmbedDim
    ffnDim
    posEncDim
    vocabDim
    hasDropout =
    GEncoderDecoderTransformer
      inputEmbedDim
      -- encoder stack
      (EDTEncoderF style numEncoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout)
      -- decoder stack
      (EDTDecoderF style numDecoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim hasDropout)
      -- token embedding
      (EDTSharedEmbeddingF style gradient device dataType inputEmbedDim vocabDim)
      -- optional head
      (EDTHeadF style transformerHead gradient device dataType inputEmbedDim vocabDim)
-- | Encoder sub-model of an encoder-decoder transformer.
--
-- It is the style-specific 'TransformerEncoderF' stack wrapped in
-- 'NamedModel'.
type family
  EDTEncoderF
    (style :: TransformerStyle)
    (numEncoderLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  EDTEncoderF
    style
    numEncoderLayers
    gradient
    device
    dataType
    headDim
    headEmbedDim
    embedDim
    inputEmbedDim
    ffnDim
    posEncDim
    hasDropout =
    NamedModel
      ( TransformerEncoderF
          style
          numEncoderLayers
          gradient
          device
          dataType
          headDim
          headEmbedDim
          embedDim
          inputEmbedDim
          ffnDim
          posEncDim
          hasDropout
      )
-- | Decoder sub-model of an encoder-decoder transformer.
--
-- It is the style-specific 'TransformerDecoderF' stack wrapped in
-- 'NamedModel'.
type family
  EDTDecoderF
    (style :: TransformerStyle)
    (numDecoderLayers :: Nat)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (headDim :: Dim (Name Symbol) (Size Nat))
    (headEmbedDim :: Dim (Name Symbol) (Size Nat))
    (embedDim :: Dim (Name Symbol) (Size Nat))
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (posEncDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  EDTDecoderF
    style
    numDecoderLayers
    gradient
    device
    dataType
    headDim
    headEmbedDim
    embedDim
    inputEmbedDim
    ffnDim
    posEncDim
    hasDropout =
    NamedModel
      ( TransformerDecoderF
          style
          numDecoderLayers
          gradient
          device
          dataType
          headDim
          headEmbedDim
          embedDim
          -- NOTE(review): inputEmbedDim is passed twice. Presumably the first
          -- is the decoder input embedding dim and the second the encoder
          -- output dim. Confirm against TransformerDecoderF's parameter order.
          inputEmbedDim
          inputEmbedDim
          ffnDim
          posEncDim
          hasDropout
      )
-- | Token embedding sub-model of an encoder-decoder transformer.
--
-- It is a dense 'Embedding' from @vocabDim@ to @inputEmbedDim@, wrapped in
-- 'NamedModel'. The style argument does not affect the result.
--
-- The explicit @:: Type@ result kind matches the sibling @EDT*F@ families
-- and gives the family a complete kind signature.
type family
  EDTSharedEmbeddingF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  EDTSharedEmbeddingF _ gradient device dataType inputEmbedDim vocabDim =
    NamedModel (Embedding gradient ('Layout 'Dense) device dataType vocabDim inputEmbedDim 'Nothing)
-- | Head sub-model of an encoder-decoder transformer.
--
-- For 'WithoutHead' it is the unit type. For 'WithLMHead' it is the
-- style-specific 'GLMHeadF' head wrapped in 'NamedModel'.
type family
  EDTHeadF
    (style :: TransformerStyle)
    (transformerHead :: TransformerHead)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (vocabDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  -- No head: none of the other parameters matter.
  EDTHeadF _ 'WithoutHead _ _ _ _ _ = ()
  EDTHeadF style 'WithLMHead gradient device dataType inputEmbedDim vocabDim =
    NamedModel (GLMHeadF style gradient device dataType inputEmbedDim vocabDim)
encoderDecoderTransformerSpec ::
forall style transformerHead numEncoderLayers numDecoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim hasDropout.
STransformerStyle style ->
STransformerHead transformerHead ->
SNat numEncoderLayers ->
SNat numDecoderLayers ->
SGradient gradient ->
SDevice device ->
SDataType dataType ->
SDim headDim ->
SDim headEmbedDim ->
SDim embedDim ->
SDim inputEmbedDim ->
SDim ffnDim ->
SDim posEncDim ->
SDim vocabDim ->
SHasDropout hasDropout ->
Double ->
Double ->
ModelSpec (GEncoderDecoderTransformerF style transformerHead numEncoderLayers numDecoderLayers gradient device dataType headDim headEmbedDim embedDim inputEmbedDim ffnDim posEncDim vocabDim hasDropout)
encoderDecoderTransformerSpec :: forall (style :: TransformerStyle)
(transformerHead :: TransformerHead) (numEncoderLayers :: Nat)
(numDecoderLayers :: Nat) (gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat))
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(ffnDim :: Dim (Name Symbol) (Size Nat))
(posEncDim :: Dim (Name Symbol) (Size Nat))
(vocabDim :: Dim (Name Symbol) (Size Nat))
(hasDropout :: HasDropout).
STransformerStyle style
-> STransformerHead transformerHead
-> SNat numEncoderLayers
-> SNat numDecoderLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SDim vocabDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
(GEncoderDecoderTransformerF
style
transformerHead
numEncoderLayers
numDecoderLayers
gradient
device
dataType
headDim
headEmbedDim
embedDim
inputEmbedDim
ffnDim
posEncDim
vocabDim
hasDropout)
encoderDecoderTransformerSpec STransformerStyle style
style STransformerHead transformerHead
transformerHead SNat numEncoderLayers
numEncoderLayers SNat numDecoderLayers
numDecoderLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SDim vocabDim
vocabDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps =
let encoderSpec :: STransformerStyle style
-> NamedModel
(GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numEncoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout)))
encoderSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numEncoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'T5
ST5
encoderSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numEncoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'ByT5
SByT5
encoderSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numEncoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'BART
SBART
encoderSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numEncoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'MBART
SMBART
encoderSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.encoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numEncoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle 'Pegasus
SPegasus
encoderSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
encoderSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
encoderSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
decoderSpec :: STransformerStyle style
-> NamedModel
(GTransformer
(ModelSpec
(TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TDRelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TDInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numDecoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GCrossAttention
(CAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(CADropoutF style hasDropout)
(CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TDFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDFinalDropoutF style hasDropout)))
decoderSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"decoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TDRelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TDInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numDecoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GCrossAttention
(CAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(CADropoutF style hasDropout)
(CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TDFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDFinalDropoutF style hasDropout))
decoderSpec' STransformerStyle 'T5
ST5
decoderSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"decoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TDRelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TDInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numDecoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GCrossAttention
(CAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(CADropoutF style hasDropout)
(CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TDFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDFinalDropoutF style hasDropout))
decoderSpec' STransformerStyle 'ByT5
SByT5
decoderSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.decoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TDRelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TDInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numDecoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GCrossAttention
(CAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(CADropoutF style hasDropout)
(CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TDFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDFinalDropoutF style hasDropout))
decoderSpec' STransformerStyle 'BART
SBART
decoderSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.decoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TDRelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TDInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numDecoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GCrossAttention
(CAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(CADropoutF style hasDropout)
(CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TDFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDFinalDropoutF style hasDropout))
decoderSpec' STransformerStyle 'MBART
SMBART
decoderSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.decoder." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GTransformer
(ModelSpec
(TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TDRelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TDInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numDecoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GCrossAttention
(CAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(CADropoutF style hasDropout)
(CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TDFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDFinalDropoutF style hasDropout))
decoderSpec' STransformerStyle 'Pegasus
SPegasus
decoderSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
decoderSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
decoderSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
sharedEmbeddingSpec :: STransformerStyle style
-> NamedModel
(EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing)
sharedEmbeddingSpec STransformerStyle style
ST5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"shared." EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing
sharedEmbeddingSpec'
sharedEmbeddingSpec STransformerStyle style
SByT5 = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"shared." EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing
sharedEmbeddingSpec'
sharedEmbeddingSpec STransformerStyle style
SBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.shared." EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing
sharedEmbeddingSpec'
sharedEmbeddingSpec STransformerStyle style
SMBART = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.shared." EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing
sharedEmbeddingSpec'
sharedEmbeddingSpec STransformerStyle style
SPegasus = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"model.shared." EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing
sharedEmbeddingSpec'
sharedEmbeddingSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
sharedEmbeddingSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
sharedEmbeddingSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
headSpec :: STransformerStyle style
-> STransformerHead transformerHead
-> ModelSpec
(EDTHeadF
style
transformerHead
gradient
device
dataType
inputEmbedDim
vocabDim)
headSpec STransformerStyle style
ST5 STransformerHead transformerHead
SWithoutHead = ()
headSpec STransformerStyle style
ST5 STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"lm_head." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
inputEmbedDim
(ModelSpec
(LMHeadDenseF style gradient device dataType inputEmbedDim))
(ModelSpec (LMHeadActivationF style))
(ModelSpec
(LMHeadLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec
(LMHeadDecoderF
style gradient device dataType inputEmbedDim vocabDim))
(ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'T5
ST5
headSpec STransformerStyle style
SByT5 STransformerHead transformerHead
SWithoutHead = ()
headSpec STransformerStyle style
SByT5 STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel StateDictKey
"lm_head." forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
inputEmbedDim
(ModelSpec
(LMHeadDenseF style gradient device dataType inputEmbedDim))
(ModelSpec (LMHeadActivationF style))
(ModelSpec
(LMHeadLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec
(LMHeadDecoderF
style gradient device dataType inputEmbedDim vocabDim))
(ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'ByT5
SByT5
headSpec STransformerStyle style
SBART STransformerHead transformerHead
SWithoutHead = ()
headSpec STransformerStyle style
SBART STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
inputEmbedDim
(ModelSpec
(LMHeadDenseF style gradient device dataType inputEmbedDim))
(ModelSpec (LMHeadActivationF style))
(ModelSpec
(LMHeadLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec
(LMHeadDecoderF
style gradient device dataType inputEmbedDim vocabDim))
(ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'BART
SBART
headSpec STransformerStyle style
SMBART STransformerHead transformerHead
SWithoutHead = ()
headSpec STransformerStyle style
SMBART STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
inputEmbedDim
(ModelSpec
(LMHeadDenseF style gradient device dataType inputEmbedDim))
(ModelSpec (LMHeadActivationF style))
(ModelSpec
(LMHeadLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec
(LMHeadDecoderF
style gradient device dataType inputEmbedDim vocabDim))
(ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'MBART
SMBART
headSpec STransformerStyle style
SPegasus STransformerHead transformerHead
SWithoutHead = ()
headSpec STransformerStyle style
SPegasus STransformerHead transformerHead
SWithLMHead = forall model. StateDictKey -> model -> NamedModel model
NamedModel forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$ forall {style :: TransformerStyle}.
STransformerStyle style
-> GLMHead
inputEmbedDim
(ModelSpec
(LMHeadDenseF style gradient device dataType inputEmbedDim))
(ModelSpec (LMHeadActivationF style))
(ModelSpec
(LMHeadLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec
(LMHeadDecoderF
style gradient device dataType inputEmbedDim vocabDim))
(ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle 'Pegasus
SPegasus
headSpec STransformerStyle style
SBERT STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
headSpec STransformerStyle style
SRoBERTa STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
headSpec STransformerStyle style
SGPT2 STransformerHead transformerHead
_ = forall a. HasCallStack => a
undefined
embedScalingSpec :: STransformerStyle style -> EncoderDecoderTransformerHasEmbedScaling
embedScalingSpec :: STransformerStyle style -> EncoderDecoderTransformerHasEmbedScaling
embedScalingSpec STransformerStyle style
ST5 = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling
embedScalingSpec STransformerStyle style
SByT5 = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling
embedScalingSpec STransformerStyle style
SBART = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling
embedScalingSpec STransformerStyle style
SMBART = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling
embedScalingSpec STransformerStyle style
SPegasus = EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithEmbedScaling
embedScalingSpec STransformerStyle style
SBERT = forall a. HasCallStack => a
undefined
embedScalingSpec STransformerStyle style
SRoBERTa = forall a. HasCallStack => a
undefined
embedScalingSpec STransformerStyle style
SGPT2 = forall a. HasCallStack => a
undefined
in forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
SDim inputEmbedDim
-> encoder
-> decoder
-> sharedEmbedding
-> head
-> EncoderDecoderTransformerHasEmbedScaling
-> GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
GEncoderDecoderTransformer SDim inputEmbedDim
inputEmbedDim (STransformerStyle style
-> NamedModel
(GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numEncoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout)))
encoderSpec STransformerStyle style
style) (STransformerStyle style
-> NamedModel
(GTransformer
(ModelSpec
(TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TDRelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TDInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numDecoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GCrossAttention
(CAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(CADropoutF style hasDropout)
(CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TDFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDFinalDropoutF style hasDropout)))
decoderSpec STransformerStyle style
style) (STransformerStyle style
-> NamedModel
(EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing)
sharedEmbeddingSpec STransformerStyle style
style) (STransformerStyle style
-> STransformerHead transformerHead
-> ModelSpec
(EDTHeadF
style
transformerHead
gradient
device
dataType
inputEmbedDim
vocabDim)
headSpec STransformerStyle style
style STransformerHead transformerHead
transformerHead) (STransformerStyle style -> EncoderDecoderTransformerHasEmbedScaling
embedScalingSpec STransformerStyle style
style)
where
encoderSpec' :: _
encoderSpec' :: STransformerStyle style
-> GTransformer
(ModelSpec
(TEPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TERelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TEInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numEncoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
()
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TEFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TEFinalDropoutF style hasDropout))
encoderSpec' STransformerStyle style
style' = forall (style :: TransformerStyle) (numLayers :: Nat)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat))
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(ffnDim :: Dim (Name Symbol) (Size Nat))
(posEncDim :: Dim (Name Symbol) (Size Nat))
(hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim inputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
(TransformerEncoderF
style
numLayers
gradient
device
dataType
headDim
headEmbedDim
embedDim
inputEmbedDim
ffnDim
posEncDim
hasDropout)
transformerEncoderSpec STransformerStyle style
style' SNat numEncoderLayers
numEncoderLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps
decoderSpec' :: _
decoderSpec' :: STransformerStyle style
-> GTransformer
(ModelSpec
(TDPosEncF style gradient device dataType inputEmbedDim posEncDim))
(ModelSpec
(TDRelPosEncF style gradient device dataType headDim posEncDim))
(ModelSpec
(TDInitialLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDInitialDropoutF style hasDropout))
(NamedModel
(GTransformerStack
(VectorSpec
numDecoderLayers
(GTransformerBlock
(NamedModel
(GSelfAttention
(SAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(SADropoutF style hasDropout)
(SAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GCrossAttention
(CAInitialLayerNormF style gradient device dataType inputEmbedDim)
(NamedModel
(GMultiHeadAttention
headDim
headEmbedDim
embedDim
(QInProjF style gradient device dataType inputEmbedDim embedDim)
(KInProjF style gradient device dataType inputEmbedDim embedDim)
(VInProjF style gradient device dataType inputEmbedDim embedDim)
(OutProjF style gradient device dataType embedDim inputEmbedDim)
(DropoutF style hasDropout)))
(CADropoutF style hasDropout)
(CAFinalLayerNormF style gradient device dataType inputEmbedDim)))
(NamedModel
(GTransformerFeedForwardNetwork
(FFNInputLayerNormF style gradient device dataType inputEmbedDim)
(FFNInputTransformationF
style gradient device dataType inputEmbedDim ffnDim)
(FFNActivationF style)
(FFNActivationDropoutF style hasDropout)
(FFNOutputProjectionF
style gradient device dataType inputEmbedDim ffnDim)
(FFNOutputDropoutF style hasDropout)
(FFNOutputLayerNormF
style gradient device dataType inputEmbedDim)))))))
(ModelSpec
(TDFinalLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec (TDFinalDropoutF style hasDropout))
decoderSpec' STransformerStyle style
style' = forall (style :: TransformerStyle) (numLayers :: Nat)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(headDim :: Dim (Name Symbol) (Size Nat))
(headEmbedDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat))
(decoderInputEmbedDim :: Dim (Name Symbol) (Size Nat))
(encoderOutputEmbedDim :: Dim (Name Symbol) (Size Nat))
(ffnDim :: Dim (Name Symbol) (Size Nat))
(posEncDim :: Dim (Name Symbol) (Size Nat))
(hasDropout :: HasDropout).
STransformerStyle style
-> SNat numLayers
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim headDim
-> SDim headEmbedDim
-> SDim embedDim
-> SDim decoderInputEmbedDim
-> SDim encoderOutputEmbedDim
-> SDim ffnDim
-> SDim posEncDim
-> SHasDropout hasDropout
-> Double
-> Double
-> ModelSpec
(TransformerDecoderF
style
numLayers
gradient
device
dataType
headDim
headEmbedDim
embedDim
decoderInputEmbedDim
encoderOutputEmbedDim
ffnDim
posEncDim
hasDropout)
transformerDecoderSpec STransformerStyle style
style' SNat numDecoderLayers
numDecoderLayers SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim headDim
headDim SDim headEmbedDim
headEmbedDim SDim embedDim
embedDim SDim inputEmbedDim
inputEmbedDim SDim inputEmbedDim
inputEmbedDim SDim ffnDim
ffnDim SDim posEncDim
posEncDim SHasDropout hasDropout
hasDropout Double
dropoutP Double
eps
sharedEmbeddingSpec' :: EmbeddingSpec
gradient
('Layout 'Dense)
device
dataType
vocabDim
inputEmbedDim
'Nothing
sharedEmbeddingSpec' = forall (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(embedNumDim :: Dim (Name Symbol) (Size Nat))
(embedDim :: Dim (Name Symbol) (Size Nat))
(paddingIdx :: Maybe Nat).
SGradient gradient
-> SLayout layout
-> SDevice device
-> SDataType dataType
-> SDim embedNumDim
-> SDim embedDim
-> SMaybe paddingIdx
-> EmbeddingSpec
gradient layout device dataType embedNumDim embedDim paddingIdx
EmbeddingSpec SGradient gradient
gradient (forall (layoutType :: LayoutType).
SLayoutType layoutType -> SLayout ('Layout layoutType)
SLayout SLayoutType 'Dense
SDense) SDevice device
device SDataType dataType
dataType SDim vocabDim
vocabDim SDim inputEmbedDim
inputEmbedDim forall a. SMaybe 'Nothing
SNothing
headSpec' :: _
headSpec' :: STransformerStyle style
-> GLMHead
inputEmbedDim
(ModelSpec
(LMHeadDenseF style gradient device dataType inputEmbedDim))
(ModelSpec (LMHeadActivationF style))
(ModelSpec
(LMHeadLayerNormF style gradient device dataType inputEmbedDim))
(ModelSpec
(LMHeadDecoderF
style gradient device dataType inputEmbedDim vocabDim))
(ModelSpec (LMHeadBiasF style gradient device dataType vocabDim))
headSpec' STransformerStyle style
style' = forall (style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(vocabDim :: Dim (Name Symbol) (Size Nat)).
STransformerStyle style
-> SGradient gradient
-> SDevice device
-> SDataType dataType
-> SDim inputEmbedDim
-> SDim vocabDim
-> Double
-> ModelSpec
(GLMHeadF style gradient device dataType inputEmbedDim vocabDim)
lmHeadSpec STransformerStyle style
style' SGradient gradient
gradient SDevice device
device SDataType dataType
dataType SDim inputEmbedDim
inputEmbedDim SDim vocabDim
vocabDim Double
eps
instance
( HasInitialize encoder generatorDevice encoder' generatorDevice0,
HasInitialize decoder generatorDevice0 decoder' generatorDevice1,
HasInitialize sharedEmbedding generatorDevice1 sharedEmbedding' generatorDevice2,
HasInitialize head generatorDevice2 head' generatorOutputDevice
) =>
HasInitialize
(GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
generatorDevice
(GEncoderDecoderTransformer inputEmbedDim encoder' decoder' sharedEmbedding' head')
generatorOutputDevice
instance
( HasStateDict encoder,
HasStateDict decoder,
HasStateDict sharedEmbedding,
HasStateDict head
) =>
HasStateDict (GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
data
GSimplifiedEncoderDecoderTransformer
(model :: Type)
(mkPos :: Type)
(mkDecoderPos :: Type)
(mkPaddingMask :: Type)
(mkAttentionMask :: Type)
(mkCrossAttentionMask :: Type)
(mkDecoderAttentionMask :: Type)
where
GSimplifiedEncoderDecoderTransformer ::
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask.
{
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> model
sedtModel :: model,
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> ShiftRight Int
sedtDecoderInputShift :: ShiftRight Int,
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> ShiftRight Int
sedtPaddingMaskShift :: ShiftRight Int,
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> mkPos
sedtMkPos :: mkPos,
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> mkDecoderPos
sedtMkDecoderPos :: mkDecoderPos,
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> mkPaddingMask
sedtMkPaddingMask :: mkPaddingMask,
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> mkAttentionMask
sedtMkAttentionMask :: mkAttentionMask,
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> mkCrossAttentionMask
sedtMkCrossAttentionMask :: mkCrossAttentionMask,
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> mkDecoderAttentionMask
sedtMkDecoderAttentionMask :: mkDecoderAttentionMask
} ->
GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask
deriving stock (GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Eq model, Eq mkPos, Eq mkDecoderPos, Eq mkPaddingMask,
Eq mkAttentionMask, Eq mkCrossAttentionMask,
Eq mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
/= :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
$c/= :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Eq model, Eq mkPos, Eq mkDecoderPos, Eq mkPaddingMask,
Eq mkAttentionMask, Eq mkCrossAttentionMask,
Eq mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
== :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
$c== :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Eq model, Eq mkPos, Eq mkDecoderPos, Eq mkPaddingMask,
Eq mkAttentionMask, Eq mkCrossAttentionMask,
Eq mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
Eq, GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Ordering
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall {model} {mkPos} {mkDecoderPos} {mkPaddingMask}
{mkAttentionMask} {mkCrossAttentionMask} {mkDecoderAttentionMask}.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
Eq
(GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask)
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Ordering
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
min :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
$cmin :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
max :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
$cmax :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
>= :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
$c>= :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
> :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
$c> :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
<= :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
$c<= :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
< :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
$c< :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Bool
compare :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Ordering
$ccompare :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Ord model, Ord mkPos, Ord mkDecoderPos, Ord mkPaddingMask,
Ord mkAttentionMask, Ord mkCrossAttentionMask,
Ord mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Ordering
Ord, Int
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
Show mkAttentionMask, Show mkCrossAttentionMask,
Show mkDecoderAttentionMask) =>
Int
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> ShowS
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
Show mkAttentionMask, Show mkCrossAttentionMask,
Show mkDecoderAttentionMask) =>
[GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask]
-> ShowS
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
Show mkAttentionMask, Show mkCrossAttentionMask,
Show mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> String
showList :: [GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask]
-> ShowS
$cshowList :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
Show mkAttentionMask, Show mkCrossAttentionMask,
Show mkDecoderAttentionMask) =>
[GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask]
-> ShowS
show :: GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> String
$cshow :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
Show mkAttentionMask, Show mkCrossAttentionMask,
Show mkDecoderAttentionMask) =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> String
showsPrec :: Int
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> ShowS
$cshowsPrec :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask.
(Show model, Show mkPos, Show mkDecoderPos, Show mkPaddingMask,
Show mkAttentionMask, Show mkCrossAttentionMask,
Show mkDecoderAttentionMask) =>
Int
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask x.
Rep
(GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask)
x
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask x.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Rep
(GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask)
x
$cto :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask x.
Rep
(GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask)
x
-> GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
$cfrom :: forall model mkPos mkDecoderPos mkPaddingMask mkAttentionMask
mkCrossAttentionMask mkDecoderAttentionMask x.
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> Rep
(GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask)
x
Generic)
type instance
ModelSpec (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask) =
GSimplifiedEncoderDecoderTransformer (ModelSpec model) (ModelSpec mkPos) (ModelSpec mkDecoderPos) (ModelSpec mkPaddingMask) (ModelSpec mkAttentionMask) (ModelSpec mkCrossAttentionMask) (ModelSpec mkDecoderAttentionMask)
instance
( HasInitialize model generatorDevice model' generatorDevice0,
HasInitialize mkPos generatorDevice0 mkPos' generatorDevice1,
HasInitialize mkDecoderPos generatorDevice1 mkDecoderPos' generatorDevice2,
HasInitialize mkPaddingMask generatorDevice2 mkPaddingMask' generatorDevice3,
HasInitialize mkAttentionMask generatorDevice3 mkAttentionMask' generatorDevice4,
HasInitialize mkCrossAttentionMask generatorDevice4 mkCrossAttentionMask' generatorDevice5,
HasInitialize mkDecoderAttentionMask generatorDevice5 mkDecoderAttentionMask' generatorOutputDevice
) =>
HasInitialize
(GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
generatorDevice
(GSimplifiedEncoderDecoderTransformer model' mkPos' mkDecoderPos' mkPaddingMask' mkAttentionMask' mkCrossAttentionMask' mkDecoderAttentionMask')
generatorOutputDevice
instance
( HasStateDict model,
HasStateDict mkPos,
HasStateDict mkDecoderPos,
HasStateDict mkPaddingMask,
HasStateDict mkAttentionMask,
HasStateDict mkCrossAttentionMask,
HasStateDict mkDecoderAttentionMask
) =>
HasStateDict (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
data EncoderDecoderTransformerInput input decoderInput pos decoderPos attentionMask decoderAttentionMask crossAttentionMask where
EncoderDecoderTransformerInput ::
forall input decoderInput pos decoderPos attentionMask decoderAttentionMask crossAttentionMask.
{ forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> input
edtInput :: input,
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> decoderInput
edtDecoderInput :: decoderInput,
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> pos
edtPos :: pos,
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> decoderPos
edtDecoderPos :: decoderPos,
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> attentionMask
edtAttentionMask :: attentionMask,
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> decoderAttentionMask
edtDecoderAttentionMask :: decoderAttentionMask,
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> crossAttentionMask
edtCrossAttentionMask :: crossAttentionMask
} ->
EncoderDecoderTransformerInput input decoderInput pos decoderPos attentionMask decoderAttentionMask crossAttentionMask
deriving stock (EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Eq input, Eq decoderInput, Eq pos, Eq decoderPos,
Eq attentionMask, Eq decoderAttentionMask,
Eq crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
/= :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
$c/= :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Eq input, Eq decoderInput, Eq pos, Eq decoderPos,
Eq attentionMask, Eq decoderAttentionMask,
Eq crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
== :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
$c== :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Eq input, Eq decoderInput, Eq pos, Eq decoderPos,
Eq attentionMask, Eq decoderAttentionMask,
Eq crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
Eq, EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Ordering
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall {input} {decoderInput} {pos} {decoderPos} {attentionMask}
{decoderAttentionMask} {crossAttentionMask}.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
Eq
(EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask)
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Ordering
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
min :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
$cmin :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
max :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
$cmax :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
>= :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
$c>= :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
> :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
$c> :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
<= :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
$c<= :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
< :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
$c< :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Bool
compare :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Ordering
$ccompare :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Ord input, Ord decoderInput, Ord pos, Ord decoderPos,
Ord attentionMask, Ord decoderAttentionMask,
Ord crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Ordering
Ord, Int
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Show input, Show decoderInput, Show pos, Show decoderPos,
Show attentionMask, Show decoderAttentionMask,
Show crossAttentionMask) =>
Int
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> ShowS
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Show input, Show decoderInput, Show pos, Show decoderPos,
Show attentionMask, Show decoderAttentionMask,
Show crossAttentionMask) =>
[EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask]
-> ShowS
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Show input, Show decoderInput, Show pos, Show decoderPos,
Show attentionMask, Show decoderAttentionMask,
Show crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> String
showList :: [EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask]
-> ShowS
$cshowList :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Show input, Show decoderInput, Show pos, Show decoderPos,
Show attentionMask, Show decoderAttentionMask,
Show crossAttentionMask) =>
[EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask]
-> ShowS
show :: EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> String
$cshow :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Show input, Show decoderInput, Show pos, Show decoderPos,
Show attentionMask, Show decoderAttentionMask,
Show crossAttentionMask) =>
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> String
showsPrec :: Int
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> ShowS
$cshowsPrec :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
(Show input, Show decoderInput, Show pos, Show decoderPos,
Show attentionMask, Show decoderAttentionMask,
Show crossAttentionMask) =>
Int
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask x.
Rep
(EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask)
x
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask x.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Rep
(EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask)
x
$cto :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask x.
Rep
(EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask)
x
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
$cfrom :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask x.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Rep
(EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask)
x
Generic)
-- | Reduced input record for an encoder-decoder transformer that carries
-- only the (encoder-side) input together with its positions and its
-- attention mask. Unlike 'EncoderDecoderTransformerInput' it has no
-- decoder-input, decoder-position, decoder-mask or cross-attention fields.
--
-- All three components are left fully polymorphic.
--
-- NOTE(review): presumably used when only the encoder half of the model is
-- evaluated; confirm against the 'HasForward' instances that consume it.
data EncoderDecoderTransformerInput' input pos attentionMask where
  EncoderDecoderTransformerInput' ::
    forall input pos attentionMask.
    { -- | The input fed to the model.
      edtInput' :: input,
      -- | Positions associated with 'edtInput''.
      edtPos' :: pos,
      -- | Attention mask associated with 'edtInput''.
      edtAttentionMask' :: attentionMask
    } ->
    EncoderDecoderTransformerInput' input pos attentionMask
  deriving stock (Eq, Ord, Show, Generic)
-- | Simplified input record for an encoder-decoder transformer: it pairs
-- an (encoder-side) input with a decoder input and carries no explicit
-- position, attention-mask or cross-attention fields.
--
-- NOTE(review): presumably positions and masks are derived from the two
-- inputs by whatever consumes this record; confirm against its
-- 'HasForward' instances.
data SimplifiedEncoderDecoderTransformerInput input decoderInput where
  SimplifiedEncoderDecoderTransformerInput ::
    forall input decoderInput.
    { -- | The (encoder-side) input.
      sedtInput :: input,
      -- | The decoder input.
      sedtDecoderInput :: decoderInput
    } ->
    SimplifiedEncoderDecoderTransformerInput input decoderInput
  deriving stock (Eq, Ord, Show, Generic)
-- | Simplified input record that carries only a single (encoder-side)
-- input and nothing else: no decoder input, positions or masks.
--
-- NOTE(review): presumably the encoder-only counterpart of
-- 'SimplifiedEncoderDecoderTransformerInput'; confirm against the
-- 'HasForward' instances that consume it.
data SimplifiedEncoderDecoderTransformerInput' input where
  SimplifiedEncoderDecoderTransformerInput' ::
    forall input.
    { -- | The (encoder-side) input.
      sedtInput' :: input
    } ->
    SimplifiedEncoderDecoderTransformerInput' input
  deriving stock (Eq, Ord, Show, Generic)
data SimplifiedEncoderDecoderTransformerTrainingInput input target where
SimplifiedEncoderDecoderTransformerTrainingInput ::
forall input target.
{ forall input target.
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> input
sedtTrainingInput :: input,
forall input target.
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> target
sedtTarget :: target
} ->
SimplifiedEncoderDecoderTransformerTrainingInput input target
deriving stock (SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall input target.
(Eq input, Eq target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
/= :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
$c/= :: forall input target.
(Eq input, Eq target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
== :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
$c== :: forall input target.
(Eq input, Eq target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
Eq, SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Ordering
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall {input} {target}.
(Ord input, Ord target) =>
Eq (SimplifiedEncoderDecoderTransformerTrainingInput input target)
forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Ordering
forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
min :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
$cmin :: forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
max :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
$cmax :: forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
>= :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
$c>= :: forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
> :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
$c> :: forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
<= :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
$c<= :: forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
< :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
$c< :: forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Bool
compare :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Ordering
$ccompare :: forall input target.
(Ord input, Ord target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Ordering
Ord, Int
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall input target.
(Show input, Show target) =>
Int
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> ShowS
forall input target.
(Show input, Show target) =>
[SimplifiedEncoderDecoderTransformerTrainingInput input target]
-> ShowS
forall input target.
(Show input, Show target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> String
showList :: [SimplifiedEncoderDecoderTransformerTrainingInput input target]
-> ShowS
$cshowList :: forall input target.
(Show input, Show target) =>
[SimplifiedEncoderDecoderTransformerTrainingInput input target]
-> ShowS
show :: SimplifiedEncoderDecoderTransformerTrainingInput input target
-> String
$cshow :: forall input target.
(Show input, Show target) =>
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> String
showsPrec :: Int
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> ShowS
$cshowsPrec :: forall input target.
(Show input, Show target) =>
Int
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall input target x.
Rep
(SimplifiedEncoderDecoderTransformerTrainingInput input target) x
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
forall input target x.
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Rep
(SimplifiedEncoderDecoderTransformerTrainingInput input target) x
$cto :: forall input target x.
Rep
(SimplifiedEncoderDecoderTransformerTrainingInput input target) x
-> SimplifiedEncoderDecoderTransformerTrainingInput input target
$cfrom :: forall input target x.
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> Rep
(SimplifiedEncoderDecoderTransformerTrainingInput input target) x
Generic)
-- | Result record of an encoder-decoder transformer forward pass.
--
-- It bundles a decoder output together with an encoder output. Both
-- components are left polymorphic; the concrete tensor types are chosen
-- by the 'HasForward' instances that produce this record.
data EncoderDecoderTransformerOutput decoderOutput encoderOutput where
  EncoderDecoderTransformerOutput ::
    forall decoderOutput encoderOutput.
    { -- | The decoder's output.
      edtDecoderOutput :: decoderOutput,
      -- | The encoder's output (returned alongside the decoder output,
      -- presumably so callers can reuse it — confirm at call sites).
      edtEncoderOutput :: encoderOutput
    } ->
    EncoderDecoderTransformerOutput decoderOutput encoderOutput
  deriving stock (Eq, Ord, Show, Generic)
-- | Encoder-only counterpart of 'EncoderDecoderTransformerOutput'.
--
-- It carries just the encoder output and has no decoder component
-- (presumably produced when only the encoder has been run — confirm
-- against the producing 'HasForward' instance).
data EncoderDecoderTransformerOutput' encoderOutput where
  EncoderDecoderTransformerOutput' ::
    forall encoderOutput.
    { -- | The encoder's output.
      edtEncoderOutput' :: encoderOutput
    } ->
    EncoderDecoderTransformerOutput' encoderOutput
  deriving stock (Eq, Ord, Show, Generic)
-- | Result record of the simplified encoder-decoder transformer interface.
--
-- Besides the decoder and encoder outputs, it also returns the original
-- decoder input and the input padding mask. All four components are
-- polymorphic; concrete types are fixed by the producing 'HasForward'
-- instance.
data SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask where
  SimplifiedEncoderDecoderTransformerOutput ::
    forall decoderOutput encoderOutput decoderInput inputPaddingMask.
    { -- | The decoder's output.
      sedtDecoderOutput :: decoderOutput,
      -- | The encoder's output.
      sedtEncoderOutput :: encoderOutput,
      -- | The decoder input as originally supplied (presumably before any
      -- right-shifting applied inside the model — confirm).
      sedtOriginalDecoderInput :: decoderInput,
      -- | Padding mask associated with the input (presumably the encoder
      -- input — confirm).
      sedtInputPaddingMask :: inputPaddingMask
    } ->
    SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask
  deriving stock (Eq, Ord, Show, Generic)
data SimplifiedEncoderDecoderTransformerOutput' encoderOutput inputPaddingMask where
SimplifiedEncoderDecoderTransformerOutput' ::
forall encoderOutput inputPaddingMask.
{ forall encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> encoderOutput
sedtEncoderOutput' :: encoderOutput,
forall encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> inputPaddingMask
sedtInputPaddingMask' :: inputPaddingMask
} ->
SimplifiedEncoderDecoderTransformerOutput' encoderOutput inputPaddingMask
deriving stock (SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall encoderOutput inputPaddingMask.
(Eq encoderOutput, Eq inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
/= :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
$c/= :: forall encoderOutput inputPaddingMask.
(Eq encoderOutput, Eq inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
== :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
$c== :: forall encoderOutput inputPaddingMask.
(Eq encoderOutput, Eq inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
Eq, SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Ordering
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall {encoderOutput} {inputPaddingMask}.
(Ord encoderOutput, Ord inputPaddingMask) =>
Eq
(SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask)
forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Ordering
forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
min :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
$cmin :: forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
max :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
$cmax :: forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
>= :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
$c>= :: forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
> :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
$c> :: forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
<= :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
$c<= :: forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
< :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
$c< :: forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Bool
compare :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Ordering
$ccompare :: forall encoderOutput inputPaddingMask.
(Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Ordering
Ord, Int
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall encoderOutput inputPaddingMask.
(Show encoderOutput, Show inputPaddingMask) =>
Int
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> ShowS
forall encoderOutput inputPaddingMask.
(Show encoderOutput, Show inputPaddingMask) =>
[SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask]
-> ShowS
forall encoderOutput inputPaddingMask.
(Show encoderOutput, Show inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> String
showList :: [SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask]
-> ShowS
$cshowList :: forall encoderOutput inputPaddingMask.
(Show encoderOutput, Show inputPaddingMask) =>
[SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask]
-> ShowS
show :: SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> String
$cshow :: forall encoderOutput inputPaddingMask.
(Show encoderOutput, Show inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> String
showsPrec :: Int
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> ShowS
$cshowsPrec :: forall encoderOutput inputPaddingMask.
(Show encoderOutput, Show inputPaddingMask) =>
Int
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall encoderOutput inputPaddingMask x.
Rep
(SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask)
x
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
forall encoderOutput inputPaddingMask x.
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Rep
(SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask)
x
$cto :: forall encoderOutput inputPaddingMask x.
Rep
(SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask)
x
-> SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
$cfrom :: forall encoderOutput inputPaddingMask x.
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> Rep
(SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask)
x
Generic)
-- | Training output of the simplified encoder-decoder transformer.
--
-- A single-field record that carries the value computed in training mode;
-- the type of that value is left polymorphic (@loss@).
data SimplifiedEncoderDecoderTransformerTrainingOutput loss where
  SimplifiedEncoderDecoderTransformerTrainingOutput ::
    forall loss.
    { -- | the loss value
      sedtLoss :: loss
    } ->
    SimplifiedEncoderDecoderTransformerTrainingOutput loss
  deriving stock (Eq, Ord, Show, Generic)
-- | Input record for running an encoder-decoder transformer in generation
-- mode.
--
-- It bundles everything the decoder side needs: the decoder input, an
-- already-computed encoder output, decoder positions, and the decoder
-- self-attention and cross-attention masks. Every component is a separate
-- type parameter, so callers can choose the concrete representation.
data EncoderDecoderTransformerGenerationInput decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask where
  EncoderDecoderTransformerGenerationInput ::
    forall decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask.
    { -- | decoder input
      edtGenerationDecoderInput :: decoderInput,
      -- | encoder output (presumably produced by an earlier encoder pass; verify against callers)
      edtGenerationEncoderOutput :: encoderOutput,
      -- | decoder positions
      edtGenerationDecoderPos :: decoderPos,
      -- | decoder attention mask
      edtGenerationDecoderAttentionMask :: decoderAttentionMask,
      -- | cross-attention mask
      edtGenerationCrossAttentionMask :: crossAttentionMask
    } ->
    EncoderDecoderTransformerGenerationInput decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask
  deriving stock (Eq, Ord, Show, Generic)
data SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask where
SimplifiedEncoderDecoderTransformerGenerationInput ::
forall decoderInput encoderOutput inputPaddingMask.
{ forall decoderInput encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> decoderInput
sedtGenerationDecoderInput :: decoderInput,
forall decoderInput encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> encoderOutput
sedtGenerationEncoderOutput :: encoderOutput,
forall decoderInput encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> inputPaddingMask
sedtGenerationInputPaddingMask :: inputPaddingMask
} ->
SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask
deriving stock (SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall decoderInput encoderOutput inputPaddingMask.
(Eq decoderInput, Eq encoderOutput, Eq inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
/= :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
$c/= :: forall decoderInput encoderOutput inputPaddingMask.
(Eq decoderInput, Eq encoderOutput, Eq inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
== :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
$c== :: forall decoderInput encoderOutput inputPaddingMask.
(Eq decoderInput, Eq encoderOutput, Eq inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
Eq, SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Ordering
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall {decoderInput} {encoderOutput} {inputPaddingMask}.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
Eq
(SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask)
forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Ordering
forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
min :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
$cmin :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
max :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
$cmax :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
>= :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
$c>= :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
> :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
$c> :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
<= :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
$c<= :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
< :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
$c< :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Bool
compare :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Ordering
$ccompare :: forall decoderInput encoderOutput inputPaddingMask.
(Ord decoderInput, Ord encoderOutput, Ord inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Ordering
Ord, Int
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> ShowS
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
Int
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> ShowS
forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
[SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask]
-> ShowS
forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> String
showList :: [SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask]
-> ShowS
$cshowList :: forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
[SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask]
-> ShowS
show :: SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> String
$cshow :: forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> String
showsPrec :: Int
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> ShowS
$cshowsPrec :: forall decoderInput encoderOutput inputPaddingMask.
(Show decoderInput, Show encoderOutput, Show inputPaddingMask) =>
Int
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> ShowS
Show, forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall decoderInput encoderOutput inputPaddingMask x.
Rep
(SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask)
x
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
forall decoderInput encoderOutput inputPaddingMask x.
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Rep
(SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask)
x
$cto :: forall decoderInput encoderOutput inputPaddingMask x.
Rep
(SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask)
x
-> SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
$cfrom :: forall decoderInput encoderOutput inputPaddingMask x.
SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask
-> Rep
(SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput inputPaddingMask)
x
Generic)
instance
( HasForward
(GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
(EncoderDecoderTransformerInput' input pos attentionMask)
generatorDevice
(EncoderDecoderTransformerOutput' encoderOutput)
generatorDevice',
HasForward
(GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
(EncoderDecoderTransformerGenerationInput decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask)
generatorDevice'
(EncoderDecoderTransformerOutput headOutput encoderOutput)
generatorOutputDevice
) =>
HasForward
(GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
(EncoderDecoderTransformerInput input decoderInput pos decoderPos attentionMask decoderAttentionMask crossAttentionMask)
generatorDevice
(EncoderDecoderTransformerOutput headOutput encoderOutput)
generatorOutputDevice
where
forward :: forall (m :: * -> *).
MonadThrow m =>
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> Generator generatorDevice
-> m (EncoderDecoderTransformerOutput headOutput encoderOutput,
Generator generatorOutputDevice)
forward GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
model EncoderDecoderTransformerInput {input
pos
attentionMask
decoderInput
decoderPos
decoderAttentionMask
crossAttentionMask
edtCrossAttentionMask :: crossAttentionMask
edtDecoderAttentionMask :: decoderAttentionMask
edtAttentionMask :: attentionMask
edtDecoderPos :: decoderPos
edtPos :: pos
edtDecoderInput :: decoderInput
edtInput :: input
edtCrossAttentionMask :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> crossAttentionMask
edtDecoderAttentionMask :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> decoderAttentionMask
edtAttentionMask :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> attentionMask
edtDecoderPos :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> decoderPos
edtPos :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> pos
edtDecoderInput :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> decoderInput
edtInput :: forall input decoderInput pos decoderPos attentionMask
decoderAttentionMask crossAttentionMask.
EncoderDecoderTransformerInput
input
decoderInput
pos
decoderPos
attentionMask
decoderAttentionMask
crossAttentionMask
-> input
..} =
forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT
( forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
model
EncoderDecoderTransformerInput'
{ edtInput' :: input
edtInput' = input
edtInput,
edtPos' :: pos
edtPos' = pos
edtPos,
edtAttentionMask' :: attentionMask
edtAttentionMask' = attentionMask
edtAttentionMask
}
)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \EncoderDecoderTransformerOutput' {encoderOutput
edtEncoderOutput' :: encoderOutput
edtEncoderOutput' :: forall encoderOutput.
EncoderDecoderTransformerOutput' encoderOutput -> encoderOutput
..} ->
forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT
( forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
model
EncoderDecoderTransformerGenerationInput
{ edtGenerationDecoderInput :: decoderInput
edtGenerationDecoderInput = decoderInput
edtDecoderInput,
edtGenerationEncoderOutput :: encoderOutput
edtGenerationEncoderOutput = encoderOutput
edtEncoderOutput',
edtGenerationDecoderPos :: decoderPos
edtGenerationDecoderPos = decoderPos
edtDecoderPos,
edtGenerationDecoderAttentionMask :: decoderAttentionMask
edtGenerationDecoderAttentionMask = decoderAttentionMask
edtDecoderAttentionMask,
edtGenerationCrossAttentionMask :: crossAttentionMask
edtGenerationCrossAttentionMask = crossAttentionMask
edtCrossAttentionMask
}
)
)
instance
( HasForward
sharedEmbedding
input
generatorDevice
embeddingOutput
embeddingGeneratorOutputDevice,
embeddingOutput ~ Tensor requiresGradient' layout' device' dataType' shape',
HasForward
encoder
(embeddingOutput, pos, attentionMask)
embeddingGeneratorOutputDevice
encoderOutput
generatorOutputDevice
) =>
HasForward
(GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
(EncoderDecoderTransformerInput' input pos attentionMask)
generatorDevice
(EncoderDecoderTransformerOutput' encoderOutput)
generatorOutputDevice
where
forward :: forall (m :: * -> *).
MonadThrow m =>
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> EncoderDecoderTransformerInput' input pos attentionMask
-> Generator generatorDevice
-> m (EncoderDecoderTransformerOutput' encoderOutput,
Generator generatorOutputDevice)
forward GEncoderDecoderTransformer {sharedEmbedding
encoder
decoder
head
SDim inputEmbedDim
EncoderDecoderTransformerHasEmbedScaling
edtEmbedScaling :: EncoderDecoderTransformerHasEmbedScaling
edtHead :: head
edtSharedEmbedding :: sharedEmbedding
edtDecoder :: decoder
edtEncoder :: encoder
edtInputEmbedDim :: SDim inputEmbedDim
edtEmbedScaling :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> EncoderDecoderTransformerHasEmbedScaling
edtHead :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> head
edtSharedEmbedding :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> sharedEmbedding
edtDecoder :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> decoder
edtEncoder :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> encoder
edtInputEmbedDim :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> SDim inputEmbedDim
..} EncoderDecoderTransformerInput' {input
pos
attentionMask
edtAttentionMask' :: attentionMask
edtPos' :: pos
edtInput' :: input
edtAttentionMask' :: forall input pos attentionMask.
EncoderDecoderTransformerInput' input pos attentionMask
-> attentionMask
edtPos' :: forall input pos attentionMask.
EncoderDecoderTransformerInput' input pos attentionMask -> pos
edtInput' :: forall input pos attentionMask.
EncoderDecoderTransformerInput' input pos attentionMask -> input
..} =
let Double
scaling :: Double = forall a. Floating a => a -> a
sqrt forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (Integral a, Num b) => a -> b
fromIntegral forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. IsChecked a -> a
forgetIsChecked forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall name size. Dim name size -> size
dimSize forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall k (a :: k). SingKind k => Sing a -> Demote k
fromSing forall a b. (a -> b) -> a -> b
$ SDim inputEmbedDim
edtInputEmbedDim
in forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn input
edtInput'
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward sharedEmbedding
edtSharedEmbedding
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ( \case
EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithEmbedScaling -> forall a b c. (a -> b -> c) -> b -> a -> c
flip forall other (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(Scalar other, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> other -> m (Tensor gradient layout device dataType shape)
mulScalar Double
scaling
)
EncoderDecoderTransformerHasEmbedScaling
edtEmbedScaling
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= (\Tensor requiresGradient' layout' device' dataType' shape'
input' -> forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall a b. (a -> b) -> a -> b
$ forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward encoder
edtEncoder (Tensor requiresGradient' layout' device' dataType' shape'
input', pos
edtPos', attentionMask
edtAttentionMask'))
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= (\encoderOutput
encoderOutput -> forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn (forall encoderOutput.
encoderOutput -> EncoderDecoderTransformerOutput' encoderOutput
EncoderDecoderTransformerOutput' encoderOutput
encoderOutput))
instance
( HasForward
sharedEmbedding
decoderInput
generatorDevice
embeddingOutput'
embeddingGeneratorOutputDevice',
embeddingOutput' ~ Tensor requiresGradient'' layout'' device'' dataType'' shape'',
HasForward
decoder
( embeddingOutput',
encoderOutput,
decoderPos,
decoderAttentionMask,
crossAttentionMask
)
embeddingGeneratorOutputDevice'
decoderOutput
decoderGeneratorOutputDevice,
HasForward
head
decoderOutput
decoderGeneratorOutputDevice
headOutput
generatorOutputDevice
) =>
HasForward
(GEncoderDecoderTransformer inputEmbedDim encoder decoder sharedEmbedding head)
(EncoderDecoderTransformerGenerationInput decoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask)
generatorDevice
(EncoderDecoderTransformerOutput headOutput encoderOutput)
generatorOutputDevice
where
forward :: forall (m :: * -> *).
MonadThrow m =>
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> EncoderDecoderTransformerGenerationInput
decoderInput
encoderOutput
decoderPos
decoderAttentionMask
crossAttentionMask
-> Generator generatorDevice
-> m (EncoderDecoderTransformerOutput headOutput encoderOutput,
Generator generatorOutputDevice)
forward GEncoderDecoderTransformer {sharedEmbedding
decoder
head
encoder
SDim inputEmbedDim
EncoderDecoderTransformerHasEmbedScaling
edtEmbedScaling :: EncoderDecoderTransformerHasEmbedScaling
edtHead :: head
edtSharedEmbedding :: sharedEmbedding
edtDecoder :: decoder
edtEncoder :: encoder
edtInputEmbedDim :: SDim inputEmbedDim
edtEmbedScaling :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> EncoderDecoderTransformerHasEmbedScaling
edtHead :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> head
edtSharedEmbedding :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> sharedEmbedding
edtDecoder :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> decoder
edtEncoder :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> encoder
edtInputEmbedDim :: forall (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) encoder
decoder sharedEmbedding head.
GEncoderDecoderTransformer
inputEmbedDim encoder decoder sharedEmbedding head
-> SDim inputEmbedDim
..} EncoderDecoderTransformerGenerationInput {decoderInput
encoderOutput
decoderPos
decoderAttentionMask
crossAttentionMask
edtGenerationCrossAttentionMask :: crossAttentionMask
edtGenerationDecoderAttentionMask :: decoderAttentionMask
edtGenerationDecoderPos :: decoderPos
edtGenerationEncoderOutput :: encoderOutput
edtGenerationDecoderInput :: decoderInput
edtGenerationCrossAttentionMask :: forall decoderInput encoderOutput decoderPos decoderAttentionMask
crossAttentionMask.
EncoderDecoderTransformerGenerationInput
decoderInput
encoderOutput
decoderPos
decoderAttentionMask
crossAttentionMask
-> crossAttentionMask
edtGenerationDecoderAttentionMask :: forall decoderInput encoderOutput decoderPos decoderAttentionMask
crossAttentionMask.
EncoderDecoderTransformerGenerationInput
decoderInput
encoderOutput
decoderPos
decoderAttentionMask
crossAttentionMask
-> decoderAttentionMask
edtGenerationDecoderPos :: forall decoderInput encoderOutput decoderPos decoderAttentionMask
crossAttentionMask.
EncoderDecoderTransformerGenerationInput
decoderInput
encoderOutput
decoderPos
decoderAttentionMask
crossAttentionMask
-> decoderPos
edtGenerationEncoderOutput :: forall decoderInput encoderOutput decoderPos decoderAttentionMask
crossAttentionMask.
EncoderDecoderTransformerGenerationInput
decoderInput
encoderOutput
decoderPos
decoderAttentionMask
crossAttentionMask
-> encoderOutput
edtGenerationDecoderInput :: forall decoderInput encoderOutput decoderPos decoderAttentionMask
crossAttentionMask.
EncoderDecoderTransformerGenerationInput
decoderInput
encoderOutput
decoderPos
decoderAttentionMask
crossAttentionMask
-> decoderInput
..} =
let Double
scaling :: Double = forall a. Floating a => a -> a
sqrt forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (Integral a, Num b) => a -> b
fromIntegral forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. IsChecked a -> a
forgetIsChecked forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall name size. Dim name size -> size
dimSize forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall k (a :: k). SingKind k => Sing a -> Demote k
fromSing forall a b. (a -> b) -> a -> b
$ SDim inputEmbedDim
edtInputEmbedDim
in forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn decoderInput
edtGenerationDecoderInput
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward sharedEmbedding
edtSharedEmbedding
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ( \case
EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithoutEmbedScaling -> forall (f :: * -> *) a. Applicative f => a -> f a
pure
EncoderDecoderTransformerHasEmbedScaling
EncoderDecoderTransformerWithEmbedScaling -> forall a b c. (a -> b -> c) -> b -> a -> c
flip forall other (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(Scalar other, MonadThrow m) =>
Tensor gradient layout device dataType shape
-> other -> m (Tensor gradient layout device dataType shape)
mulScalar Double
scaling
)
EncoderDecoderTransformerHasEmbedScaling
edtEmbedScaling
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \Tensor requiresGradient'' layout'' device'' dataType'' shape''
decoderInput' ->
forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall a b. (a -> b) -> a -> b
$ forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward decoder
edtDecoder (Tensor requiresGradient'' layout'' device'' dataType'' shape''
decoderInput', encoderOutput
edtGenerationEncoderOutput, decoderPos
edtGenerationDecoderPos, decoderAttentionMask
edtGenerationDecoderAttentionMask, crossAttentionMask
edtGenerationCrossAttentionMask)
)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward head
edtHead
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= \headOutput
decoderOutput -> forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn (forall decoderOutput encoderOutput.
decoderOutput
-> encoderOutput
-> EncoderDecoderTransformerOutput decoderOutput encoderOutput
EncoderDecoderTransformerOutput headOutput
decoderOutput encoderOutput
edtGenerationEncoderOutput)
instance
( HasForward
(GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
(SimplifiedEncoderDecoderTransformerInput' input)
generatorDevice
(SimplifiedEncoderDecoderTransformerOutput' encoderOutput inputPaddingMask)
generatorDevice',
HasForward
(GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
(SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask)
generatorDevice'
(SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
generatorOutputDevice
) =>
HasForward
(GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
(SimplifiedEncoderDecoderTransformerInput input decoderInput)
generatorDevice
(SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
generatorOutputDevice
where
forward :: forall (m :: * -> *).
MonadThrow m =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> SimplifiedEncoderDecoderTransformerInput input decoderInput
-> Generator generatorDevice
-> m (SimplifiedEncoderDecoderTransformerOutput
decoderOutput encoderOutput decoderInput inputPaddingMask,
Generator generatorOutputDevice)
forward GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
model SimplifiedEncoderDecoderTransformerInput {input
decoderInput
sedtDecoderInput :: decoderInput
sedtInput :: input
sedtDecoderInput :: forall input decoderInput.
SimplifiedEncoderDecoderTransformerInput input decoderInput
-> decoderInput
sedtInput :: forall input decoderInput.
SimplifiedEncoderDecoderTransformerInput input decoderInput
-> input
..} =
forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT (forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
model SimplifiedEncoderDecoderTransformerInput' {sedtInput' :: input
sedtInput' = input
sedtInput})
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= ( \SimplifiedEncoderDecoderTransformerOutput' {encoderOutput
inputPaddingMask
sedtInputPaddingMask' :: inputPaddingMask
sedtEncoderOutput' :: encoderOutput
sedtInputPaddingMask' :: forall encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> inputPaddingMask
sedtEncoderOutput' :: forall encoderOutput inputPaddingMask.
SimplifiedEncoderDecoderTransformerOutput'
encoderOutput inputPaddingMask
-> encoderOutput
..} ->
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn forall a b. (a -> b) -> a -> b
$
SimplifiedEncoderDecoderTransformerGenerationInput
{ sedtGenerationDecoderInput :: decoderInput
sedtGenerationDecoderInput = decoderInput
sedtDecoderInput,
sedtGenerationEncoderOutput :: encoderOutput
sedtGenerationEncoderOutput = encoderOutput
sedtEncoderOutput',
sedtGenerationInputPaddingMask :: inputPaddingMask
sedtGenerationInputPaddingMask = inputPaddingMask
sedtInputPaddingMask'
}
)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
model
-- | Encoder-only forward pass of 'GSimplifiedEncoderDecoderTransformer'.
--
-- Given a @SimplifiedEncoderDecoderTransformerInput'@ this
--
--   1. builds the input padding mask by running 'sedtMkPaddingMask' on the input,
--   2. builds the encoder positions by running 'sedtMkPos' on the input,
--   3. builds the encoder attention mask by running 'sedtMkAttentionMask' on
--      the padding mask from step 1,
--   4. runs 'sedtModel' on the resulting @EncoderDecoderTransformerInput'@,
--
-- and returns the model's encoder output together with the input padding mask
-- from step 1. The random generator is threaded through every step, in the
-- order listed above, via 'IxStateT'.
instance
  ( HasForward
      mkPaddingMask
      input
      generatorDevice
      inputPaddingMask
      generatorDevice,
    HasForward
      mkAttentionMask
      inputPaddingMask
      generatorDevice
      attentionMask
      generatorDevice,
    HasForward
      mkPos
      input
      generatorDevice
      pos
      generatorDevice,
    HasForward
      model
      (EncoderDecoderTransformerInput' input pos attentionMask)
      generatorDevice
      (EncoderDecoderTransformerOutput' encoderOutput)
      generatorOutputDevice
  ) =>
  HasForward
    (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
    (SimplifiedEncoderDecoderTransformerInput' input)
    generatorDevice
    (SimplifiedEncoderDecoderTransformerOutput' encoderOutput inputPaddingMask)
    generatorOutputDevice
  where
  forward GSimplifiedEncoderDecoderTransformer {..} SimplifiedEncoderDecoderTransformerInput' {..} =
    runIxStateT $
      -- Step 1: padding mask of the raw input.
      IxStateT (forward sedtMkPaddingMask sedtInput')
        >>>= ( \inputPaddingMask ->
                 -- The rest of the pipeline lives inside this lambda so that
                 -- 'inputPaddingMask' stays in scope and can be returned below.
                 let pos = IxStateT . forward sedtMkPos $ sedtInput'
                     attentionMask = IxStateT . forward sedtMkAttentionMask $ inputPaddingMask
                  in -- Steps 2 and 3 run in this order (pos first, then
                     -- attention mask) when assembling the model input.
                     EncoderDecoderTransformerInput' sedtInput'
                       <<$>> pos
                       <<*>> attentionMask
                       -- Step 4: run the wrapped encoder-decoder model.
                       >>>= IxStateT . forward sedtModel
                       >>>= ( \EncoderDecoderTransformerOutput' {..} ->
                                ireturn $
                                  SimplifiedEncoderDecoderTransformerOutput'
                                    { sedtEncoderOutput' = edtEncoderOutput',
                                      sedtInputPaddingMask' = inputPaddingMask
                                    }
                            )
             )
-- | Decoder (generation) forward pass of 'GSimplifiedEncoderDecoderTransformer'.
--
-- Given a 'SimplifiedEncoderDecoderTransformerGenerationInput' (decoder input,
-- encoder output, and encoder-input padding mask) this
--
--   1. right-shifts the decoder input with 'sedtDecoderInputShift',
--   2. builds the decoder padding mask by running 'sedtMkPaddingMask' on the
--      /unshifted/ decoder input, then right-shifts that mask with
--      'sedtPaddingMaskShift',
--   3. builds the decoder positions ('sedtMkDecoderPos') from the shifted
--      decoder input, the decoder attention mask ('sedtMkDecoderAttentionMask')
--      from the shifted padding mask, and the cross-attention mask
--      ('sedtMkCrossAttentionMask') from the shifted decoder input paired with
--      the encoder-input padding mask,
--   4. runs 'sedtModel' on the resulting 'EncoderDecoderTransformerGenerationInput'.
--
-- The result holds the decoder and encoder outputs returned by the model,
-- plus the original (unshifted) decoder input and the encoder-input padding
-- mask that were passed in. The random generator is threaded through every
-- step via 'IxStateT'.
instance
  ( HasForward
      mkPaddingMask
      decoderInput
      generatorDevice
      decoderInputPaddingMask
      generatorDevice,
    HasForward
      mkCrossAttentionMask
      (rightShiftedDecoderInput, inputPaddingMask)
      generatorDevice
      crossAttentionMask
      generatorDevice,
    HasForward
      mkDecoderAttentionMask
      rightShiftedDecoderInputPaddingMask
      generatorDevice
      decoderAttentionMask
      generatorDevice,
    HasForward
      (ShiftRight Int)
      decoderInput
      generatorDevice
      rightShiftedDecoderInput
      generatorDevice,
    HasForward
      (ShiftRight Int)
      decoderInputPaddingMask
      generatorDevice
      rightShiftedDecoderInputPaddingMask
      generatorDevice,
    HasForward
      mkDecoderPos
      rightShiftedDecoderInput
      generatorDevice
      decoderPos
      generatorDevice,
    HasForward
      model
      (EncoderDecoderTransformerGenerationInput rightShiftedDecoderInput encoderOutput decoderPos decoderAttentionMask crossAttentionMask)
      generatorDevice
      (EncoderDecoderTransformerOutput decoderOutput encoderOutput)
      generatorOutputDevice
  ) =>
  HasForward
    (GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
    (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask)
    generatorDevice
    (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
    generatorOutputDevice
  where
  forward GSimplifiedEncoderDecoderTransformer {..} SimplifiedEncoderDecoderTransformerGenerationInput {..} =
    runIxStateT $
      -- Steps 1 and 2: shift the decoder input, then derive and shift its
      -- padding mask (the mask is computed from the unshifted input).
      ( let rightShiftedDecoderInput =
              IxStateT . forward sedtDecoderInputShift $ sedtGenerationDecoderInput
            rightShiftedDecoderInputPaddingMask =
              ireturn sedtGenerationDecoderInput
                >>>= IxStateT . forward sedtMkPaddingMask
                >>>= IxStateT . forward sedtPaddingMaskShift
         in (,)
              <<$>> rightShiftedDecoderInput
              <<*>> rightShiftedDecoderInputPaddingMask
      )
        >>>= ( \(rightShiftedDecoderInput, rightShiftedDecoderInputPaddingMask) ->
                 -- Step 3: positions and masks for the decoder.
                 let decoderPos =
                       IxStateT . forward sedtMkDecoderPos $ rightShiftedDecoderInput
                     crossAttentionMask =
                       IxStateT . forward sedtMkCrossAttentionMask $
                         (rightShiftedDecoderInput, sedtGenerationInputPaddingMask)
                     decoderAttentionMask =
                       IxStateT . forward sedtMkDecoderAttentionMask $
                         rightShiftedDecoderInputPaddingMask
                  in -- Effects run in applicative order: positions, decoder
                     -- attention mask, then cross-attention mask.
                     ( EncoderDecoderTransformerGenerationInput
                         rightShiftedDecoderInput
                         sedtGenerationEncoderOutput
                         <<$>> decoderPos
                         <<*>> decoderAttentionMask
                         <<*>> crossAttentionMask
                     )
                       -- Step 4: run the wrapped encoder-decoder model.
                       >>>= IxStateT . forward sedtModel
                       >>>= ( \(EncoderDecoderTransformerOutput decoderOutput encoderOutput) ->
                                -- Return the original (unshifted) decoder input,
                                -- not the right-shifted one fed to the model.
                                ireturn $
                                  SimplifiedEncoderDecoderTransformerOutput
                                    decoderOutput
                                    encoderOutput
                                    sedtGenerationDecoderInput
                                    sedtGenerationInputPaddingMask
                            )
             )
instance
( HasForward
(GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
(SimplifiedEncoderDecoderTransformerInput input decoderInput)
generatorDevice
(SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
generatorOutputDevice,
decoderInput
~ Tensor
targetGradient
targetLayout
targetDevice
targetDataType
(IndexDims ('Indices '[ 'SliceAll, 'SliceUpTo ('NegativeIndex 1)]) targetShape),
decoderOutput
~ Tensor
doGradient
doLayout
doDevice
doDataType
doShape,
logProbsShape ~ SoftmaxF ('SelectDim ('ByIndex 2)) doShape,
Catch logProbsShape,
unsqueezedTargetShape ~ UnsqueezeF ('SelectDim ('ByIndex 2)) targetShape,
Catch unsqueezedTargetShape,
gatheredLogProbsShape ~ GatherDimF ('SelectDim ('ByIndex 2)) unsqueezedTargetShape logProbsShape,
Catch gatheredLogProbsShape,
Catch (targetDataType <+> 'DataType 'Int64),
logLikelihoodShape ~ SqueezeDimF ('SelectDim ('ByIndex 2)) gatheredLogProbsShape,
Catch logLikelihoodShape,
MeanAllCheckF logLikelihoodShape,
loss
~ Tensor
(targetGradient <|> doGradient)
(targetLayout <+> doLayout)
(targetDevice <+> doDevice)
doDataType
('Shape '[]),
generatorOutputDevice ~ generatorDevice
) =>
HasForward
(GSimplifiedEncoderDecoderTransformer model mkPos mkDecoderPos mkPaddingMask mkAttentionMask mkCrossAttentionMask mkDecoderAttentionMask)
( SimplifiedEncoderDecoderTransformerTrainingInput
input
(Tensor targetGradient targetLayout targetDevice targetDataType targetShape)
)
generatorDevice
(SimplifiedEncoderDecoderTransformerTrainingOutput loss)
generatorOutputDevice
where
forward :: forall (m :: * -> *).
MonadThrow m =>
GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
-> SimplifiedEncoderDecoderTransformerTrainingInput
input
(Tensor
targetGradient
targetLayout
targetDevice
targetDataType
targetShape)
-> Generator generatorDevice
-> m (SimplifiedEncoderDecoderTransformerTrainingOutput loss,
Generator generatorOutputDevice)
forward GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
eot SimplifiedEncoderDecoderTransformerTrainingInput {input
Tensor
targetGradient targetLayout targetDevice targetDataType targetShape
sedtTarget :: Tensor
targetGradient targetLayout targetDevice targetDataType targetShape
sedtTrainingInput :: input
sedtTarget :: forall input target.
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> target
sedtTrainingInput :: forall input target.
SimplifiedEncoderDecoderTransformerTrainingInput input target
-> input
..} =
forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor
targetGradient targetLayout targetDevice targetDataType targetShape
sedtTarget
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall (indices :: Indices [IndexType (Index Nat)])
(requiresGradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
MonadThrow m =>
Tensor requiresGradient layout device dataType shape
-> SIndices indices
-> m (Tensor
requiresGradient layout device dataType (IndexDims indices shape))
! forall (indexTypes :: [IndexType (Index Nat)]).
SList indexTypes -> SIndices ('Indices indexTypes)
SIndices (forall a. SIndexType 'SliceAll
SSliceAll forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a (n :: a). Sing n -> SIndexType ('SliceUpTo n)
SSliceUpTo (forall (index1 :: Nat).
KnownNat index1 =>
SIndex ('NegativeIndex index1)
SNegativeIndex @1) forall {k} (a :: k) (as :: [k]).
Sing a -> SList as -> SList (a : as)
:|: forall a. SList '[]
SNil))
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= (\Tensor
targetGradient
targetLayout
targetDevice
targetDataType
(IndexDims
('Indices '[ 'SliceAll, 'SliceUpTo ('NegativeIndex 1)])
targetShape)
sedtDecoderInput -> forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GSimplifiedEncoderDecoderTransformer
model
mkPos
mkDecoderPos
mkPaddingMask
mkAttentionMask
mkCrossAttentionMask
mkDecoderAttentionMask
eot forall a b. (a -> b) -> a -> b
$ SimplifiedEncoderDecoderTransformerInput {sedtInput :: input
sedtInput = input
sedtTrainingInput, Tensor
targetGradient
targetLayout
targetDevice
targetDataType
(IndexDims
('Indices '[ 'SliceAll, 'SliceUpTo ('NegativeIndex 1)])
targetShape)
sedtDecoderInput :: Tensor
targetGradient
targetLayout
targetDevice
targetDataType
(IndexDims
('Indices '[ 'SliceAll, 'SliceUpTo ('NegativeIndex 1)])
targetShape)
sedtDecoderInput :: Tensor
targetGradient
targetLayout
targetDevice
targetDataType
(IndexDims
('Indices '[ 'SliceAll, 'SliceUpTo ('NegativeIndex 1)])
targetShape)
sedtDecoderInput})
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall decoderOutput encoderOutput decoderInput inputPaddingMask.
SimplifiedEncoderDecoderTransformerOutput
decoderOutput encoderOutput decoderInput inputPaddingMask
-> decoderOutput
sedtDecoderOutput
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ( \Tensor doGradient doLayout doDevice doDataType doShape
logits -> do
Tensor
doGradient
doLayout
doDevice
doDataType
(SoftmaxF ('SelectDim ('ByIndex 2)) doShape)
logProbs <- forall (selectDim :: SelectDim (By Symbol Nat))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(MonadThrow m, shape' ~ SoftmaxF selectDim shape, Catch shape') =>
SSelectDim selectDim
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
logSoftmax (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim forall a b. (a -> b) -> a -> b
$ forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2) Tensor doGradient doLayout doDevice doDataType doShape
logits
Tensor
targetGradient
targetLayout
targetDevice
targetDataType
unsqueezedTargetShape
target' <- forall (selectDim :: SelectDim (By Symbol Nat))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(shape' ~ UnsqueezeF selectDim shape, Catch shape',
MonadThrow m) =>
SSelectDim selectDim
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sUnsqueeze (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim forall a b. (a -> b) -> a -> b
$ forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2) Tensor
targetGradient targetLayout targetDevice targetDataType targetShape
sedtTarget
Tensor
(Or (Gradient RequiresGradient) targetGradient doGradient)
(Unify (Layout LayoutType) targetLayout doLayout)
(Unify (Device (DeviceType Nat)) targetDevice doDevice)
doDataType
(GatherDimF
('SelectDim ('ByIndex 2))
unsqueezedTargetShape
(SoftmaxF ('SelectDim ('ByIndex 2)) doShape))
gatheredLogProbs <- forall (selectDim :: SelectDim (By Symbol Nat))
(indexGradient :: Gradient RequiresGradient)
(inputGradient :: Gradient RequiresGradient)
(indexLayout :: Layout LayoutType)
(inputLayout :: Layout LayoutType)
(indexDevice :: Device (DeviceType Nat))
(inputDevice :: Device (DeviceType Nat))
(indexDataType :: DataType DType) (inputDataType :: DataType DType)
(indexShape :: Shape [Dim (Name Symbol) (Size Nat)])
(inputShape :: Shape [Dim (Name Symbol) (Size Nat)])
(outputShape :: Shape [Dim (Name Symbol) (Size Nat)])
(m :: * -> *).
(MonadThrow m,
outputShape ~ GatherDimF selectDim indexShape inputShape,
Catch outputShape, Catch (indexDataType <+> 'DataType 'Int64)) =>
SSelectDim selectDim
-> Tensor
indexGradient indexLayout indexDevice indexDataType indexShape
-> Tensor
inputGradient inputLayout inputDevice inputDataType inputShape
-> m (Tensor
(indexGradient <|> inputGradient)
(indexLayout <+> inputLayout)
(indexDevice <+> inputDevice)
inputDataType
outputShape)
sGatherDim (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim forall a b. (a -> b) -> a -> b
$ forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2) Tensor
targetGradient
targetLayout
targetDevice
targetDataType
unsqueezedTargetShape
target' Tensor
doGradient
doLayout
doDevice
doDataType
(SoftmaxF ('SelectDim ('ByIndex 2)) doShape)
logProbs
Tensor
(Or (Gradient RequiresGradient) targetGradient doGradient)
(Unify (Layout LayoutType) targetLayout doLayout)
(Unify (Device (DeviceType Nat)) targetDevice doDevice)
doDataType
(SqueezeDimF
('SelectDim ('ByIndex 2))
(GatherDimF
('SelectDim ('ByIndex 2))
unsqueezedTargetShape
(SoftmaxF ('SelectDim ('ByIndex 2)) doShape)))
logLikelihood <- forall (selectDim :: SelectDim (By Symbol Nat))
(gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(shape' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(MonadThrow m, shape' ~ SqueezeDimF selectDim shape,
Catch shape') =>
SSelectDim selectDim
-> Tensor gradient layout device dataType shape
-> m (Tensor gradient layout device dataType shape')
sSqueezeDim (forall (by :: By Symbol Nat). SBy by -> SSelectDim ('SelectDim by)
SSelectDim forall a b. (a -> b) -> a -> b
$ forall (index :: Nat). KnownNat index => SBy ('ByIndex index)
SByIndex @2) Tensor
(Or (Gradient RequiresGradient) targetGradient doGradient)
(Unify (Layout LayoutType) targetLayout doLayout)
(Unify (Device (DeviceType Nat)) targetDevice doDevice)
doDataType
(GatherDimF
('SelectDim ('ByIndex 2))
unsqueezedTargetShape
(SoftmaxF ('SelectDim ('ByIndex 2)) doShape))
gatheredLogProbs
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Num a => a -> a
negate forall a b. (a -> b) -> a -> b
$ forall (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)]).
MeanAllCheckF shape =>
Tensor gradient layout device dataType shape
-> Tensor gradient layout device dataType ('Shape '[])
meanAll Tensor
(Or (Gradient RequiresGradient) targetGradient doGradient)
(Unify (Layout LayoutType) targetLayout doLayout)
(Unify (Device (DeviceType Nat)) targetDevice doDevice)
doDataType
(SqueezeDimF
('SelectDim ('ByIndex 2))
(GatherDimF
('SelectDim ('ByIndex 2))
unsqueezedTargetShape
(SoftmaxF ('SelectDim ('ByIndex 2)) doShape)))
logLikelihood
)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall loss.
loss -> SimplifiedEncoderDecoderTransformerTrainingOutput loss
SimplifiedEncoderDecoderTransformerTrainingOutput