{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Transformer.GFeedForwardNetwork where
import Control.Monad.Indexed (ireturn, (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType, SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Activation (Gelu (..), GeluNew (..), Relu (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Dropout (Dropout (..))
import Torch.GraduallyTyped.NN.Linear (GLinear (..), GLinearF, linearSpec)
import Torch.GraduallyTyped.NN.Normalization (LayerNorm (..), LayerNormSpec (..))
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), HasDropout (..), SHasBias (..), SHasDropout (..))
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, SShape (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
-- | Generic gating layer.
--
-- Given an input @x@, the layer computes (see the 'HasForward' instance)
--
-- > activation (layer0 x) * layer1 x
--
-- i.e. the output of @layer0@ is passed through @activation@ and then
-- multiplied pointwise with the output of @layer1@ applied to the same input.
data
  GGate
    (layer0 :: Type)
    (activation :: Type)
    (layer1 :: Type)
  where
  GGate ::
    forall layer0 activation layer1.
    { -- | first input projection of the gate; its output is activated
      gateLayer0 :: layer0,
      -- | activation applied to the output of 'gateLayer0'
      gateActivation :: activation,
      -- | second input projection of the gate; its output multiplies the activated branch
      gateLayer1 :: layer1
    } ->
    GGate layer0 activation layer1
  deriving stock (Eq, Ord, Show, Generic)
-- | The specification of a 'GGate' is a 'GGate' of the component specifications.
type instance
  ModelSpec (GGate layer0 activation layer1) =
    GGate (ModelSpec layer0) (ModelSpec activation) (ModelSpec layer1)
-- | Initialization threads the generator device through the three components
-- in order: @layer0@, then @activation@, then @layer1@.
-- The instance body is empty; the implementation comes from the class defaults.
instance
  ( HasInitialize layer0 generatorDevice layer0' generatorDevice0,
    HasInitialize activation generatorDevice0 activation' generatorDevice1,
    HasInitialize layer1 generatorDevice1 layer1' generatorOutputDevice
  ) =>
  HasInitialize
    (GGate layer0 activation layer1)
    generatorDevice
    (GGate layer0' activation' layer1')
    generatorOutputDevice
-- | State-dict support is derived generically from the components'
-- 'HasStateDict' instances; the instance body is empty (class defaults).
instance
  (HasStateDict layer0, HasStateDict activation, HasStateDict layer1) =>
  HasStateDict (GGate layer0 activation layer1)
-- | 'HasForward' instance for 'GGate'.
--
-- @output = activation (layer0 input) * layer1 input@
--
-- Both branches consume the same input tensor; the activated @layer0@ branch
-- gates the @layer1@ branch by pointwise multiplication (the 'Num' instance
-- of 'Tensor'). The generator device is threaded through @layer0@, then
-- @activation@, then @layer1@.
instance
  ( HasForward
      layer0
      (Tensor gradient layout device dataType shape)
      generatorDevice
      (Tensor gradient' layout' device' dataType' shape')
      generatorDevice',
    HasForward
      activation
      (Tensor gradient' layout' device' dataType' shape')
      generatorDevice'
      (Tensor gradient' layout' device' dataType' shape')
      generatorDevice',
    HasForward
      layer1
      (Tensor gradient layout device dataType shape)
      generatorDevice'
      (Tensor gradient' layout' device' dataType' shape')
      generatorDevice'',
    output ~ Tensor gradient' layout' device' dataType' shape',
    generatorOutputDevice ~ generatorDevice''
  ) =>
  HasForward
    (GGate layer0 activation layer1)
    (Tensor gradient layout device dataType shape)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward GGate {..} input =
    let -- feed the input through 'gateLayer0' and then 'gateActivation'
        activate input' =
          ireturn input'
            >>>= IxStateT . forward gateLayer0
            >>>= IxStateT . forward gateActivation
        -- pointwise product of the activated branch and the 'gateLayer1' branch
        gate input' = (*) <<$>> activate input' <<*>> (IxStateT . forward gateLayer1 $ input')
     in runIxStateT $ ireturn input >>>= gate
-- | Generic transformer feed-forward network.
--
-- The components are applied in declaration order (see the style-specific
-- type families below for which components are present per 'TransformerStyle';
-- absent components are instantiated at @()@):
-- layer norm on the input, input transformation, activation, activation
-- dropout, output projection, output dropout, and layer norm on the output.
data
  GTransformerFeedForwardNetwork
    (inputLayerNorm :: Type)
    (inputTransformation :: Type)
    (activation :: Type)
    (activationDropout :: Type)
    (outputProjection :: Type)
    (outputDropout :: Type)
    (outputLayerNorm :: Type)
  where
  GTransformerFeedForwardNetwork ::
    forall inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm.
    { -- | layer norm applied to the network's input
      ffnInputLayerNorm :: inputLayerNorm,
      -- | input transformation (e.g. a linear projection to the feed-forward dimension)
      ffnInputTransformation :: inputTransformation,
      -- | activation function
      ffnActivation :: activation,
      -- | dropout applied after the activation
      ffnActivationDropout :: activationDropout,
      -- | projection back to the query embedding dimension
      ffnOutputProjection :: outputProjection,
      -- | dropout applied to the projected output
      ffnOutputDropout :: outputDropout,
      -- | layer norm applied to the network's output
      ffnOutputLayerNorm :: outputLayerNorm
    } ->
    GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm
  deriving stock (Eq, Ord, Show, Generic)
-- | The specification of a 'GTransformerFeedForwardNetwork' is itself a
-- 'GTransformerFeedForwardNetwork', with each component replaced by that
-- component's own 'ModelSpec'.
type instance
  ModelSpec (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm) =
    GTransformerFeedForwardNetwork (ModelSpec inputLayerNorm) (ModelSpec inputTransformation) (ModelSpec activation) (ModelSpec activationDropout) (ModelSpec outputProjection) (ModelSpec outputDropout) (ModelSpec outputLayerNorm)
-- | Computes the concrete 'GTransformerFeedForwardNetwork' type for a given
-- transformer @style@, delegating each of the seven component types to its
-- own style-indexed type family below.
type family
  GTransformerFeedForwardNetworkF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat))
    (hasDropout :: HasDropout) ::
    Type
  where
  GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout =
    GTransformerFeedForwardNetwork
      (FFNInputLayerNormF style gradient device dataType queryEmbedDim)
      (FFNInputTransformationF style gradient device dataType queryEmbedDim ffnDim)
      (FFNActivationF style)
      (FFNActivationDropoutF style hasDropout)
      (FFNOutputProjectionF style gradient device dataType queryEmbedDim ffnDim)
      (FFNOutputDropoutF style hasDropout)
      (FFNOutputLayerNormF style gradient device dataType queryEmbedDim)
-- | Computes the type of the feed-forward network's input layer-norm:
-- without bias for T5\/ByT5, with bias for Pegasus, and absent (@()@) for
-- BART\/MBART\/BERT\/RoBERTa (which normalize at the output instead, see
-- 'FFNOutputLayerNormF').
type family
  FFNInputLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  FFNInputLayerNormF 'T5 gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithoutBias gradient device dataType ('Shape '[queryEmbedDim]))
  FFNInputLayerNormF 'ByT5 gradient device dataType queryEmbedDim =
    FFNInputLayerNormF 'T5 gradient device dataType queryEmbedDim
  FFNInputLayerNormF 'BART _ _ _ _ =
    ()
  FFNInputLayerNormF 'MBART gradient device dataType queryEmbedDim =
    FFNInputLayerNormF 'BART gradient device dataType queryEmbedDim
  FFNInputLayerNormF 'Pegasus gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  FFNInputLayerNormF 'BERT _ _ _ _ =
    ()
  FFNInputLayerNormF 'RoBERTa gradient device dataType queryEmbedDim =
    FFNInputLayerNormF 'BERT gradient device dataType queryEmbedDim
-- | Computes the type of the feed-forward network's input transformation
-- (@queryEmbedDim -> ffnDim@): a bias-free linear layer for T5, a gated
-- ('GGate') pair of bias-free linear layers around 'GeluNew' for ByT5, and a
-- biased linear layer for every other style.
type family
  FFNInputTransformationF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  FFNInputTransformationF 'T5 gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim)
  FFNInputTransformationF 'ByT5 gradient device dataType queryEmbedDim ffnDim =
    GGate
      (NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim))
      GeluNew
      (NamedModel (GLinearF 'WithoutBias gradient device dataType queryEmbedDim ffnDim))
  FFNInputTransformationF _ gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithBias gradient device dataType queryEmbedDim ffnDim)
-- | Computes the type of the feed-forward network's activation function:
-- 'Relu' for T5 and Pegasus, 'GeluNew' for ByT5, 'Gelu' otherwise.
type family
  FFNActivationF
    (style :: TransformerStyle) ::
    Type
  where
  FFNActivationF 'T5 = Relu
  FFNActivationF 'ByT5 = GeluNew
  FFNActivationF 'BART = Gelu
  FFNActivationF 'MBART = Gelu
  FFNActivationF 'Pegasus = Relu
  FFNActivationF 'BERT = Gelu
  FFNActivationF 'RoBERTa = Gelu
-- | Computes the type of the dropout applied after the activation:
-- 'Dropout' when the style supports it and @hasDropout ~ 'WithDropout@,
-- @()@ otherwise. BERT and RoBERTa never carry an activation dropout; the
-- final catch-all handles @'WithoutDropout@ for all remaining styles.
type family
  FFNActivationDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  FFNActivationDropoutF 'T5 'WithDropout = Dropout
  FFNActivationDropoutF 'ByT5 hasDropout = FFNActivationDropoutF 'T5 hasDropout
  FFNActivationDropoutF 'BART 'WithDropout = Dropout
  FFNActivationDropoutF 'MBART hasDropout = FFNActivationDropoutF 'BART hasDropout
  FFNActivationDropoutF 'Pegasus hasDropout = FFNActivationDropoutF 'BART hasDropout
  FFNActivationDropoutF 'BERT _ = ()
  FFNActivationDropoutF 'RoBERTa hasDropout = FFNActivationDropoutF 'BERT hasDropout
  FFNActivationDropoutF _ 'WithoutDropout = ()
-- | Computes the type of the feed-forward network's output projection
-- (@ffnDim -> queryEmbedDim@): bias-free for T5\/ByT5, biased otherwise.
type family
  FFNOutputProjectionF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat))
    (ffnDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  FFNOutputProjectionF 'T5 gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithoutBias gradient device dataType ffnDim queryEmbedDim)
  FFNOutputProjectionF 'ByT5 gradient device dataType queryEmbedDim ffnDim =
    FFNOutputProjectionF 'T5 gradient device dataType queryEmbedDim ffnDim
  FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithBias gradient device dataType ffnDim queryEmbedDim)
  FFNOutputProjectionF 'MBART gradient device dataType queryEmbedDim ffnDim =
    FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim
  FFNOutputProjectionF 'Pegasus gradient device dataType queryEmbedDim ffnDim =
    FFNOutputProjectionF 'BART gradient device dataType queryEmbedDim ffnDim
  FFNOutputProjectionF 'BERT gradient device dataType queryEmbedDim ffnDim =
    NamedModel (GLinearF 'WithBias gradient device dataType ffnDim queryEmbedDim)
  FFNOutputProjectionF 'RoBERTa gradient device dataType queryEmbedDim ffnDim =
    FFNOutputProjectionF 'BERT gradient device dataType queryEmbedDim ffnDim
-- | Computes the type of the dropout applied after the output projection.
-- Unlike 'FFNActivationDropoutF', this one is style-independent and depends
-- only on the @hasDropout@ flag.
type family
  FFNOutputDropoutF
    (style :: TransformerStyle)
    (hasDropout :: HasDropout) ::
    Type
  where
  FFNOutputDropoutF _ 'WithDropout = Dropout
  FFNOutputDropoutF _ 'WithoutDropout = ()
-- | Computes the type of the feed-forward network's output layer-norm:
-- absent (@()@) for T5\/ByT5\/Pegasus (which normalize at the input instead,
-- see 'FFNInputLayerNormF'), and a biased 'LayerNorm' for
-- BART\/MBART\/BERT\/RoBERTa.
type family
  FFNOutputLayerNormF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (queryEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  FFNOutputLayerNormF 'T5 _ _ _ _ =
    ()
  FFNOutputLayerNormF 'ByT5 gradient device dataType queryEmbedDim =
    FFNOutputLayerNormF 'T5 gradient device dataType queryEmbedDim
  FFNOutputLayerNormF 'BART gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  FFNOutputLayerNormF 'MBART gradient device dataType queryEmbedDim =
    FFNOutputLayerNormF 'BART gradient device dataType queryEmbedDim
  FFNOutputLayerNormF 'Pegasus _ _ _ _ =
    ()
  FFNOutputLayerNormF 'BERT gradient device dataType queryEmbedDim =
    NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[queryEmbedDim]))
  FFNOutputLayerNormF 'RoBERTa gradient device dataType queryEmbedDim =
    FFNOutputLayerNormF 'BERT gradient device dataType queryEmbedDim
-- | Builds the model specification of a transformer feed-forward network for
-- the given @style@. The state-dict key prefixes (e.g. @\"DenseReluDense.wi.\"@,
-- @\"fc1.\"@, @\"intermediate.dense.\"@) follow the corresponding upstream
-- checkpoint naming for each style.
--
-- The GPT-2 style ('SGPT2') is not implemented and evaluates to 'undefined'
-- at runtime.
transformerFeedForwardNetworkSpec ::
  forall style gradient device dataType queryEmbedDim ffnDim hasDropout.
  -- | the transformer style
  STransformerStyle style ->
  -- | whether parameters require gradients
  SGradient gradient ->
  -- | the compute device
  SDevice device ->
  -- | the parameter data type
  SDataType dataType ->
  -- | the query embedding dimension
  SDim queryEmbedDim ->
  -- | the inner (hidden) feed-forward dimension
  SDim ffnDim ->
  -- | whether dropout layers are included
  SHasDropout hasDropout ->
  -- | the dropout probability
  Double ->
  -- | the layer-norm epsilon
  Double ->
  ModelSpec (GTransformerFeedForwardNetworkF style gradient device dataType queryEmbedDim ffnDim hasDropout)
transformerFeedForwardNetworkSpec style gradient device dataType queryEmbedDim ffnDim hasDropout dropoutP eps =
  let -- Input layer-norm: bias-free for T5/ByT5, biased for Pegasus, absent otherwise.
      inputLayerNormSpec :: STransformerStyle style -> ModelSpec (FFNInputLayerNormF style gradient device dataType queryEmbedDim)
      inputLayerNormSpec ST5 = NamedModel "layer_norm." layerNormWithoutBiasSpec
      inputLayerNormSpec SByT5 = NamedModel "layer_norm." layerNormWithoutBiasSpec
      inputLayerNormSpec SBART = ()
      inputLayerNormSpec SMBART = ()
      inputLayerNormSpec SPegasus = NamedModel "final_layer_norm." layerNormWithBiasSpec
      inputLayerNormSpec SBERT = ()
      inputLayerNormSpec SRoBERTa = ()
      inputLayerNormSpec SGPT2 = undefined
      -- Input transformation (queryEmbedDim -> ffnDim); ByT5 uses a gated pair
      -- of bias-free projections around GeluNew.
      inputTransformationSpec :: STransformerStyle style -> ModelSpec (FFNInputTransformationF style gradient device dataType queryEmbedDim ffnDim)
      inputTransformationSpec ST5 = NamedModel "DenseReluDense.wi." $ weightSpecWithoutBias queryEmbedDim ffnDim
      inputTransformationSpec SByT5 =
        GGate
          (NamedModel "DenseReluDense.wi_0." $ weightSpecWithoutBias queryEmbedDim ffnDim)
          GeluNew
          (NamedModel "DenseReluDense.wi_1." $ weightSpecWithoutBias queryEmbedDim ffnDim)
      inputTransformationSpec SBART = NamedModel "fc1." $ weightSpecWithBias queryEmbedDim ffnDim
      inputTransformationSpec SMBART = NamedModel "fc1." $ weightSpecWithBias queryEmbedDim ffnDim
      inputTransformationSpec SPegasus = NamedModel "fc1." $ weightSpecWithBias queryEmbedDim ffnDim
      inputTransformationSpec SBERT = NamedModel "intermediate.dense." $ weightSpecWithBias queryEmbedDim ffnDim
      inputTransformationSpec SRoBERTa = NamedModel "intermediate.dense." $ weightSpecWithBias queryEmbedDim ffnDim
      inputTransformationSpec SGPT2 = undefined
      -- Activation: Relu for T5/Pegasus, GeluNew for ByT5, Gelu otherwise.
      activationSpec :: STransformerStyle style -> ModelSpec (FFNActivationF style)
      activationSpec ST5 = Relu
      activationSpec SByT5 = GeluNew
      activationSpec SBART = Gelu
      activationSpec SMBART = Gelu
      activationSpec SPegasus = Relu
      activationSpec SBERT = Gelu
      activationSpec SRoBERTa = Gelu
      activationSpec SGPT2 = undefined
      -- Post-activation dropout; BERT/RoBERTa never have one.
      activationDropoutSpec :: STransformerStyle style -> SHasDropout hasDropout -> ModelSpec (FFNActivationDropoutF style hasDropout)
      activationDropoutSpec ST5 SWithDropout = Dropout dropoutP
      activationDropoutSpec ST5 SWithoutDropout = ()
      activationDropoutSpec SByT5 SWithDropout = Dropout dropoutP
      activationDropoutSpec SByT5 SWithoutDropout = ()
      activationDropoutSpec SBART SWithDropout = Dropout dropoutP
      activationDropoutSpec SBART SWithoutDropout = ()
      activationDropoutSpec SMBART SWithDropout = Dropout dropoutP
      activationDropoutSpec SMBART SWithoutDropout = ()
      activationDropoutSpec SPegasus SWithDropout = Dropout dropoutP
      activationDropoutSpec SPegasus SWithoutDropout = ()
      activationDropoutSpec SBERT _ = ()
      activationDropoutSpec SRoBERTa _ = ()
      activationDropoutSpec SGPT2 _ = undefined
      -- Output projection (ffnDim -> queryEmbedDim); bias-free for T5/ByT5.
      outputProjectionSpec :: STransformerStyle style -> ModelSpec (FFNOutputProjectionF style gradient device dataType queryEmbedDim ffnDim)
      outputProjectionSpec ST5 = NamedModel "DenseReluDense.wo." $ weightSpecWithoutBias ffnDim queryEmbedDim
      outputProjectionSpec SByT5 = NamedModel "DenseReluDense.wo." $ weightSpecWithoutBias ffnDim queryEmbedDim
      outputProjectionSpec SBART = NamedModel "fc2." $ weightSpecWithBias ffnDim queryEmbedDim
      outputProjectionSpec SMBART = NamedModel "fc2." $ weightSpecWithBias ffnDim queryEmbedDim
      outputProjectionSpec SPegasus = NamedModel "fc2." $ weightSpecWithBias ffnDim queryEmbedDim
      outputProjectionSpec SBERT = NamedModel "output.dense." $ weightSpecWithBias ffnDim queryEmbedDim
      outputProjectionSpec SRoBERTa = NamedModel "output.dense." $ weightSpecWithBias ffnDim queryEmbedDim
      outputProjectionSpec SGPT2 = undefined
      -- Post-projection dropout, style-independent.
      outputDropoutSpec :: STransformerStyle style -> SHasDropout hasDropout -> ModelSpec (FFNOutputDropoutF style hasDropout)
      outputDropoutSpec _ SWithDropout = Dropout dropoutP
      outputDropoutSpec _ SWithoutDropout = ()
      -- Output layer-norm: absent for T5/ByT5/Pegasus, biased otherwise.
      outputLayerNormSpec :: STransformerStyle style -> ModelSpec (FFNOutputLayerNormF style gradient device dataType queryEmbedDim)
      outputLayerNormSpec ST5 = ()
      outputLayerNormSpec SByT5 = ()
      outputLayerNormSpec SBART = NamedModel "final_layer_norm." layerNormWithBiasSpec
      outputLayerNormSpec SMBART = NamedModel "final_layer_norm." layerNormWithBiasSpec
      outputLayerNormSpec SPegasus = ()
      outputLayerNormSpec SBERT = NamedModel "output.LayerNorm." layerNormWithBiasSpec
      outputLayerNormSpec SRoBERTa = NamedModel "output.LayerNorm." layerNormWithBiasSpec
      outputLayerNormSpec SGPT2 = undefined
   in GTransformerFeedForwardNetwork
        (inputLayerNormSpec style)
        (inputTransformationSpec style)
        (activationSpec style)
        (activationDropoutSpec style hasDropout)
        (outputProjectionSpec style)
        (outputDropoutSpec style hasDropout)
        (outputLayerNormSpec style)
  where
    -- Spec for a bias-free linear layer; note the transposed weight shape
    -- '[outputDim, inputDim].
    weightSpecWithoutBias ::
      forall inputDim outputDim.
      SDim inputDim ->
      SDim outputDim ->
      ModelSpec
        ( GLinear
            (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])))
            (NamedModel ())
        )
    weightSpecWithoutBias = linearSpec SWithoutBias gradient device dataType
    -- Spec for a linear layer with bias.
    weightSpecWithBias ::
      forall inputDim outputDim.
      SDim inputDim ->
      SDim outputDim ->
      ModelSpec
        ( GLinear
            (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])))
            (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim])))
        )
    weightSpecWithBias = linearSpec SWithBias gradient device dataType
    -- Layer-norm specs over the query embedding dimension, sharing eps.
    layerNormWithoutBiasSpec = LayerNormSpec SWithoutBias gradient device dataType (SShape $ queryEmbedDim :|: SNil) eps
    layerNormWithBiasSpec = LayerNormSpec SWithBias gradient device dataType (SShape $ queryEmbedDim :|: SNil) eps
-- | Initialization threads the random generator through the seven components
-- in order (input layer-norm, input transformation, activation, activation
-- dropout, output projection, output dropout, output layer-norm), chaining
-- the generator device from each step to the next.
instance
  ( HasInitialize inputLayerNorm generatorDevice inputLayerNorm' generatorDevice0,
    HasInitialize inputTransformation generatorDevice0 inputTransformation' generatorDevice1,
    HasInitialize activation generatorDevice1 activation' generatorDevice2,
    HasInitialize activationDropout generatorDevice2 activationDropout' generatorDevice3,
    HasInitialize outputProjection generatorDevice3 outputProjection' generatorDevice4,
    HasInitialize outputDropout generatorDevice4 outputDropout' generatorDevice5,
    HasInitialize outputLayerNorm generatorDevice5 outputLayerNorm' generatorOutputDevice
  ) =>
  HasInitialize
    (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm)
    generatorDevice
    (GTransformerFeedForwardNetwork inputLayerNorm' inputTransformation' activation' activationDropout' outputProjection' outputDropout' outputLayerNorm')
    generatorOutputDevice
-- | State-dict (de)serialization is derived component-wise; each component
-- must itself support 'HasStateDict'.
instance
  ( HasStateDict inputLayerNorm,
    HasStateDict inputTransformation,
    HasStateDict activation,
    HasStateDict activationDropout,
    HasStateDict outputProjection,
    HasStateDict outputDropout,
    HasStateDict outputLayerNorm
  ) =>
  HasStateDict (GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm)
instance
( HasForward
inputLayerNorm
(Tensor queryGradient queryLayout queryDevice queryDataType queryShape)
generatorDevice
tensor0
generatorDevice0,
HasForward
inputTransformation
tensor0
generatorDevice0
tensor1
generatorDevice1,
HasForward
activation
tensor1
generatorDevice1
tensor2
generatorDevice2,
HasForward
activationDropout
tensor2
generatorDevice2
tensor3
generatorDevice3,
HasForward
outputProjection
tensor3
generatorDevice3
tensor4
generatorDevice4,
HasForward
outputDropout
tensor4
generatorDevice4
(Tensor queryGradient5 queryLayout5 queryDevice5 queryDataType5 queryShape5)
generatorDevice5,
HasForward
outputLayerNorm
(Tensor (queryGradient <|> queryGradient5) (queryLayout <+> queryLayout5) (queryDevice <+> queryDevice5) (queryDataType <+> queryDataType5) (BroadcastShapesF queryShape queryShape5))
generatorDevice5
output
generatorOutputDevice,
Catch (BroadcastShapesF queryShape queryShape5)
) =>
HasForward
(GTransformerFeedForwardNetwork inputLayerNorm inputTransformation activation activationDropout outputProjection outputDropout outputLayerNorm)
(Tensor queryGradient queryLayout queryDevice queryDataType queryShape)
generatorDevice
output
generatorOutputDevice
where
forward :: forall (m :: * -> *).
MonadThrow m =>
GTransformerFeedForwardNetwork
inputLayerNorm
inputTransformation
activation
activationDropout
outputProjection
outputDropout
outputLayerNorm
-> Tensor
queryGradient queryLayout queryDevice queryDataType queryShape
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward GTransformerFeedForwardNetwork {inputLayerNorm
inputTransformation
activation
activationDropout
outputProjection
outputDropout
outputLayerNorm
ffnOutputLayerNorm :: outputLayerNorm
ffnOutputDropout :: outputDropout
ffnOutputProjection :: outputProjection
ffnActivationDropout :: activationDropout
ffnActivation :: activation
ffnInputTransformation :: inputTransformation
ffnInputLayerNorm :: inputLayerNorm
ffnOutputLayerNorm :: forall inputLayerNorm inputTransformation activation
activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
inputLayerNorm
inputTransformation
activation
activationDropout
outputProjection
outputDropout
outputLayerNorm
-> outputLayerNorm
ffnOutputDropout :: forall inputLayerNorm inputTransformation activation
activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
inputLayerNorm
inputTransformation
activation
activationDropout
outputProjection
outputDropout
outputLayerNorm
-> outputDropout
ffnOutputProjection :: forall inputLayerNorm inputTransformation activation
activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
inputLayerNorm
inputTransformation
activation
activationDropout
outputProjection
outputDropout
outputLayerNorm
-> outputProjection
ffnActivationDropout :: forall inputLayerNorm inputTransformation activation
activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
inputLayerNorm
inputTransformation
activation
activationDropout
outputProjection
outputDropout
outputLayerNorm
-> activationDropout
ffnActivation :: forall inputLayerNorm inputTransformation activation
activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
inputLayerNorm
inputTransformation
activation
activationDropout
outputProjection
outputDropout
outputLayerNorm
-> activation
ffnInputTransformation :: forall inputLayerNorm inputTransformation activation
activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
inputLayerNorm
inputTransformation
activation
activationDropout
outputProjection
outputDropout
outputLayerNorm
-> inputTransformation
ffnInputLayerNorm :: forall inputLayerNorm inputTransformation activation
activationDropout outputProjection outputDropout outputLayerNorm.
GTransformerFeedForwardNetwork
inputLayerNorm
inputTransformation
activation
activationDropout
outputProjection
outputDropout
outputLayerNorm
-> inputLayerNorm
..} Tensor
queryGradient queryLayout queryDevice queryDataType queryShape
query =
forall (m :: * -> *) i j a. IxStateT m i j a -> i -> m (a, j)
runIxStateT forall a b. (a -> b) -> a -> b
$
forall {k} (m :: k -> k -> * -> *) a (i :: k).
IxPointed m =>
a -> m i i a
ireturn Tensor
queryGradient queryLayout queryDevice queryDataType queryShape
query
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward inputLayerNorm
ffnInputLayerNorm
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward inputTransformation
ffnInputTransformation
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward activation
ffnActivation
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward activationDropout
ffnActivationDropout
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward outputProjection
ffnOutputProjection
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward outputDropout
ffnOutputDropout
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall {k} (t :: (* -> *) -> k -> k -> * -> *) (m :: * -> *) a
(i :: k).
(IxMonadTrans t, Monad m) =>
m a -> t m i i a
ilift forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Tensor
queryGradient queryLayout queryDevice queryDataType queryShape
query forall (gradient :: Gradient RequiresGradient)
(layout :: Layout LayoutType) (device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(shape :: Shape [Dim (Name Symbol) (Size Nat)])
(gradient' :: Gradient RequiresGradient)
(layout' :: Layout LayoutType) (device' :: Device (DeviceType Nat))
(dataType' :: DataType DType)
(shape' :: Shape [Dim (Name Symbol) (Size Nat)])
(shape'' :: Shape [Dim (Name Symbol) (Size Nat)]) (m :: * -> *).
(MonadThrow m, shape'' ~ BroadcastShapesF shape shape',
Catch shape'') =>
Tensor gradient layout device dataType shape
-> Tensor gradient' layout' device' dataType' shape'
-> m (Tensor
(gradient <|> gradient')
(layout <+> layout')
(device <+> device')
(dataType <+> dataType')
shape'')
`add`)
forall {k1} (m :: k1 -> k1 -> * -> *) (i :: k1) (j :: k1) a
(k2 :: k1) b.
IxMonad m =>
m i j a -> (a -> m j k2 b) -> m i k2 b
>>>= forall (m :: * -> *) i j a. (i -> m (a, j)) -> IxStateT m i j a
IxStateT forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall model input (generatorDevice :: Device (DeviceType Nat))
output (generatorOutputDevice :: Device (DeviceType Nat))
(m :: * -> *).
(HasForward
model input generatorDevice output generatorOutputDevice,
MonadThrow m) =>
model
-> input
-> Generator generatorDevice
-> m (output, Generator generatorOutputDevice)
forward outputLayerNorm
ffnOutputLayerNorm