{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Transformer.GLMHead where
import Control.Monad.Indexed (IxPointed (..), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>))
import Data.Kind (Type)
import Data.Singletons (SingKind (fromSing))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType, DataType, SDataType (..))
import Torch.GraduallyTyped.Device (Device, DeviceType, SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Activation (Gelu (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Linear (GLinearF, linearSpec)
import Torch.GraduallyTyped.NN.Normalization (LayerNorm (..), LayerNormSpec (..))
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), SHasBias (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, SName (..), SShape (..), SSize (..), Shape (..), Size (..), pattern (:&:))
import Torch.GraduallyTyped.Tensor.Creation (sZeros)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add, mulScalar)
import Torch.GraduallyTyped.Tensor.Type (Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
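-- | Whether the language modelling head scales the decoder output by
-- @1 / sqrt inputEmbedDim@ before the bias is applied.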
data LMHeadHasScaling
  = LMHeadWithScaling
  | LMHeadWithoutScaling
  deriving stock (Eq, Ord, Show, Generic)
type instance ModelSpec LMHeadHasScaling = LMHeadHasScaling
instance HasInitialize LMHeadHasScaling generatorDevice LMHeadHasScaling generatorDevice where
  initialize hasScaling g = pure (hasScaling, g)
instance HasStateDict LMHeadHasScaling where
  fromStateDict hasScaling _ = pure hasScaling
  toStateDict _ _ = pure ()
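-- | Generic language modelling head for transformer encoders and decoders.
--
-- - @inputEmbedDim@: the dimension of the input embeddings,
-- - @dense@: a dense layer (or @()@ for styles without one),
-- - @activation@: an activation function (or @()@),
-- - @layerNorm@: a layer normalization layer (or @()@),
-- - @decoder@: the projection onto the vocabulary,
-- - @bias@: a bias wrapped in 'GBias'.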
data
  GLMHead
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (dense :: Type)
    (activation :: Type)
    (layerNorm :: Type)
    (decoder :: Type)
    (bias :: Type)
  where
  GLMHead ::
    forall inputEmbedDim dense activation layerNorm decoder bias.
    { lmHeadInputEmbedDim :: SDim inputEmbedDim,
      lmHeadDense :: dense,
      lmHeadActivation :: activation,
      lmHeadLayerNorm :: layerNorm,
      lmHeadDecoder :: decoder,
      lmHeadBias :: bias,
      lmHeadHasScaling :: LMHeadHasScaling
    } ->
    GLMHead inputEmbedDim dense activation layerNorm decoder bias
  deriving stock (Show, Generic)
type instance
ModelSpec (GLMHead inputEmbedDim dense activation layerNorm decoder bias) =
GLMHead inputEmbedDim (ModelSpec dense) (ModelSpec activation) (ModelSpec layerNorm) (ModelSpec decoder) (ModelSpec bias)
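-- | Wrapper for the bias of the language modelling head. Depending on the
-- transformer style, the bias is absent (@GBias ()@) or a (possibly named)
-- tensor.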
data GBias (bias :: Type) where
  GBias :: forall bias. bias -> GBias bias
  deriving stock (Eq, Ord, Show, Generic)
type instance ModelSpec (GBias bias) = GBias (ModelSpec bias)
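-- | Specifies the language modelling head of the transformer for a given
-- @style@.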
type family
GLMHeadF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(vocabDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
GLMHeadF style gradient device dataType inputEmbedDim vocabDim =
GLMHead
inputEmbedDim
(LMHeadDenseF style gradient device dataType inputEmbedDim)
(LMHeadActivationF style)
(LMHeadLayerNormF style gradient device dataType inputEmbedDim)
(LMHeadDecoderF style gradient device dataType inputEmbedDim vocabDim)
(LMHeadBiasF style gradient device dataType vocabDim)
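-- | Specifies the dense layer of the language modelling head.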
type family
LMHeadDenseF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadDenseF 'T5 _ _ _ _ =
()
LMHeadDenseF 'ByT5 gradient device dataType inputEmbedDim =
LMHeadDenseF 'T5 gradient device dataType inputEmbedDim
LMHeadDenseF 'BART _ _ _ _ =
()
LMHeadDenseF 'MBART gradient device dataType inputEmbedDim =
LMHeadDenseF 'BART gradient device dataType inputEmbedDim
LMHeadDenseF 'Pegasus gradient device dataType inputEmbedDim =
LMHeadDenseF 'BART gradient device dataType inputEmbedDim
LMHeadDenseF 'BERT gradient device dataType inputEmbedDim =
NamedModel (GLinearF 'WithBias gradient device dataType inputEmbedDim inputEmbedDim)
LMHeadDenseF 'RoBERTa gradient device dataType inputEmbedDim =
LMHeadDenseF 'BERT gradient device dataType inputEmbedDim
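-- | Specifies the activation function of the language modelling head.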
type family
LMHeadActivationF
(style :: TransformerStyle) ::
Type
where
LMHeadActivationF 'T5 = ()
LMHeadActivationF 'ByT5 = LMHeadActivationF 'T5
LMHeadActivationF 'BART = ()
LMHeadActivationF 'MBART = LMHeadActivationF 'BART
LMHeadActivationF 'Pegasus = LMHeadActivationF 'BART
LMHeadActivationF 'BERT = Gelu
LMHeadActivationF 'RoBERTa = LMHeadActivationF 'BERT
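-- | Specifies the layer normalization layer of the language modelling head.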
type family
LMHeadLayerNormF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadLayerNormF 'T5 _ _ _ _ = ()
LMHeadLayerNormF 'ByT5 gradient device dataType inputEmbedDim =
LMHeadLayerNormF 'T5 gradient device dataType inputEmbedDim
LMHeadLayerNormF 'BART _ _ _ _ = ()
LMHeadLayerNormF 'MBART gradient device dataType inputEmbedDim =
LMHeadLayerNormF 'BART gradient device dataType inputEmbedDim
LMHeadLayerNormF 'Pegasus gradient device dataType inputEmbedDim =
LMHeadLayerNormF 'BART gradient device dataType inputEmbedDim
LMHeadLayerNormF 'BERT gradient device dataType inputEmbedDim =
NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))
LMHeadLayerNormF 'RoBERTa gradient device dataType inputEmbedDim =
LMHeadLayerNormF 'BERT gradient device dataType inputEmbedDim
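-- | Specifies the decoder layer of the language modelling head.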
type family
LMHeadDecoderF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(vocabDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadDecoderF 'T5 gradient device dataType inputEmbedDim vocabDim =
NamedModel (GLinearF 'WithoutBias gradient device dataType inputEmbedDim vocabDim)
LMHeadDecoderF 'ByT5 gradient device dataType inputEmbedDim vocabDim =
LMHeadDecoderF 'T5 gradient device dataType inputEmbedDim vocabDim
LMHeadDecoderF 'BART gradient device dataType inputEmbedDim vocabDim =
NamedModel (GLinearF 'WithoutBias gradient device dataType inputEmbedDim vocabDim)
LMHeadDecoderF 'MBART gradient device dataType inputEmbedDim vocabDim =
LMHeadDecoderF 'BART gradient device dataType inputEmbedDim vocabDim
LMHeadDecoderF 'Pegasus gradient device dataType inputEmbedDim vocabDim =
LMHeadDecoderF 'BART gradient device dataType inputEmbedDim vocabDim
LMHeadDecoderF 'BERT gradient device dataType inputEmbedDim vocabDim =
NamedModel (GLinearF 'WithBias gradient device dataType inputEmbedDim vocabDim)
LMHeadDecoderF 'RoBERTa gradient device dataType inputEmbedDim vocabDim =
LMHeadDecoderF 'BERT gradient device dataType inputEmbedDim vocabDim
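-- | Specifies the bias of the language modelling head.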
type family
LMHeadBiasF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(vocabDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadBiasF 'T5 _ _ _ _ =
GBias ()
LMHeadBiasF 'ByT5 gradient device dataType vocabDim =
LMHeadBiasF 'T5 gradient device dataType vocabDim
LMHeadBiasF 'BART gradient device dataType vocabDim =
GBias (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[ 'Dim ('Name "*") ('Size 1), vocabDim])))
LMHeadBiasF 'MBART gradient device dataType vocabDim =
LMHeadBiasF 'BART gradient device dataType vocabDim
LMHeadBiasF 'Pegasus gradient device dataType vocabDim =
LMHeadBiasF 'BART gradient device dataType vocabDim
LMHeadBiasF 'BERT _ _ _ _ =
GBias ()
LMHeadBiasF 'RoBERTa gradient device dataType vocabDim =
LMHeadBiasF 'BERT gradient device dataType vocabDim
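-- | Specifies the parameters of the language modelling head for a given
-- transformer @style@, @gradient@, @device@, @dataType@, input embedding
-- dimension, vocabulary dimension, and layer-norm epsilon.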
lmHeadSpec ::
forall style gradient device dataType inputEmbedDim vocabDim.
STransformerStyle style ->
SGradient gradient ->
SDevice device ->
SDataType dataType ->
SDim inputEmbedDim ->
SDim vocabDim ->
Double ->
ModelSpec (GLMHeadF style gradient device dataType inputEmbedDim vocabDim)
lmHeadSpec style gradient device dataType inputEmbedDim vocabDim eps =
  let denseSpec :: STransformerStyle style -> ModelSpec (LMHeadDenseF style gradient device dataType inputEmbedDim)
      denseSpec ST5 = ()
      denseSpec SByT5 = ()
      denseSpec SBART = ()
      denseSpec SMBART = ()
      denseSpec SPegasus = ()
      denseSpec SBERT = NamedModel "transform.dense." linearSpec'
      denseSpec SRoBERTa = NamedModel "dense." linearSpec'
      denseSpec SGPT2 = undefined
      activationSpec :: STransformerStyle style -> ModelSpec (LMHeadActivationF style)
      activationSpec ST5 = ()
      activationSpec SByT5 = ()
      activationSpec SBART = ()
      activationSpec SMBART = ()
      activationSpec SPegasus = ()
      activationSpec SBERT = Gelu
      activationSpec SRoBERTa = Gelu
      activationSpec SGPT2 = undefined
      layerNormSpec :: STransformerStyle style -> ModelSpec (LMHeadLayerNormF style gradient device dataType inputEmbedDim)
      layerNormSpec ST5 = ()
      layerNormSpec SByT5 = ()
      layerNormSpec SBART = ()
      layerNormSpec SMBART = ()
      layerNormSpec SPegasus = ()
      layerNormSpec SBERT = NamedModel "transform.LayerNorm." layerNormSpec'
      layerNormSpec SRoBERTa = NamedModel "layer_norm." layerNormSpec'
      layerNormSpec SGPT2 = undefined
      decoderSpec :: STransformerStyle style -> ModelSpec (LMHeadDecoderF style gradient device dataType inputEmbedDim vocabDim)
      decoderSpec ST5 = NamedModel mempty linearWithoutBiasSpec'
      decoderSpec SByT5 = NamedModel mempty linearWithoutBiasSpec'
      decoderSpec SBART = NamedModel "lm_head." linearWithoutBiasSpec'
      decoderSpec SMBART = NamedModel "lm_head." linearWithoutBiasSpec'
      decoderSpec SPegasus = NamedModel "lm_head." linearWithoutBiasSpec'
      decoderSpec SBERT = NamedModel "decoder." linearWithBiasSpec'
      decoderSpec SRoBERTa = NamedModel "decoder." linearWithBiasSpec'
      decoderSpec SGPT2 = undefined
      biasSpec :: STransformerStyle style -> ModelSpec (LMHeadBiasF style gradient device dataType vocabDim)
      biasSpec ST5 = GBias ()
      biasSpec SByT5 = GBias ()
      biasSpec SBART = GBias (NamedModel "final_logits_bias" biasSpec')
      biasSpec SMBART = GBias (NamedModel "final_logits_bias" biasSpec')
      biasSpec SPegasus = GBias (NamedModel "final_logits_bias" biasSpec')
      biasSpec SBERT = GBias ()
      biasSpec SRoBERTa = GBias ()
      biasSpec SGPT2 = undefined
      scalingSpec :: STransformerStyle style -> LMHeadHasScaling
      scalingSpec ST5 = LMHeadWithScaling
      scalingSpec SByT5 = LMHeadWithScaling
      scalingSpec SBART = LMHeadWithoutScaling
      scalingSpec SMBART = LMHeadWithoutScaling
      scalingSpec SPegasus = LMHeadWithoutScaling
      scalingSpec SBERT = LMHeadWithoutScaling
      scalingSpec SRoBERTa = LMHeadWithoutScaling
      scalingSpec SGPT2 = undefined
   in GLMHead inputEmbedDim (denseSpec style) (activationSpec style) (layerNormSpec style) (decoderSpec style) (biasSpec style) (scalingSpec style)
  where
    linearSpec' = linearSpec SWithBias gradient device dataType inputEmbedDim inputEmbedDim
    biasSpec' = TensorSpec gradient (SLayout SDense) device dataType (SShape $ SName @"*" :&: SSize @1 :|: vocabDim :|: SNil)
    layerNormSpec' = LayerNormSpec SWithBias gradient device dataType (SShape $ inputEmbedDim :|: SNil) eps
    linearWithoutBiasSpec' = linearSpec SWithoutBias gradient device dataType inputEmbedDim vocabDim
    linearWithBiasSpec' = linearSpec SWithBias gradient device dataType inputEmbedDim vocabDim
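-- A minimal usage sketch for 'lmHeadSpec'. The concrete singleton values below
-- ('SGradient SWithGradient', 'SDevice SCPU', 'SDataType SFloat') and the
-- dimension sizes are illustrative assumptions, not values defined in this
-- module; substitute your model's configuration:
--
-- > t5LMHeadSpec =
-- >   lmHeadSpec
-- >     ST5
-- >     (SGradient SWithGradient)
-- >     (SDevice SCPU)
-- >     (SDataType SFloat)
-- >     (SName @"*" :&: SSize @512)    -- input embedding dimension
-- >     (SName @"*" :&: SSize @32128)  -- vocabulary dimension
-- >     1e-6                           -- layer-norm epsilon
--
-- The resulting spec can then be passed to 'initialize' or 'fromStateDict'
-- to obtain an actual 'GLMHead' model.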
instance
( HasInitialize dense generatorDevice dense' generatorDevice0,
HasInitialize activation generatorDevice0 activation' generatorDevice1,
HasInitialize layerNorm generatorDevice1 layerNorm' generatorDevice2,
HasInitialize decoder generatorDevice2 decoder' generatorDevice3,
HasInitialize bias generatorDevice3 bias' generatorOutputDevice
) =>
HasInitialize
(GLMHead inputEmbedDim dense activation layerNorm decoder bias)
generatorDevice
(GLMHead inputEmbedDim dense' activation' layerNorm' decoder' bias')
generatorOutputDevice
instance HasInitialize (GBias ()) generatorDevice (GBias ()) generatorDevice where
  initialize (GBias ()) g = pure (GBias (), g)
instance
HasInitialize
(GBias (Tensor biasGradient biasLayout biasDevice biasDataType biasShape))
generatorDevice
(GBias (Tensor biasGradient biasLayout biasDevice biasDataType biasShape))
generatorDevice
where
  initialize (GBias biasSpec) =
    runIxStateT (GBias <<$>> (ilift . sZeros $ biasSpec))
instance
HasInitialize
(GBias (NamedModel (Tensor biasGradient biasLayout biasDevice biasDataType biasShape)))
generatorDevice
(GBias (NamedModel (Tensor biasGradient biasLayout biasDevice biasDataType biasShape)))
generatorDevice
where
  initialize (GBias (NamedModel biasName biasSpec)) =
    runIxStateT (GBias <<$>> ilift (NamedModel biasName <$> sZeros biasSpec))
instance
( HasStateDict dense,
HasStateDict activation,
HasStateDict layerNorm,
HasStateDict decoder,
HasStateDict bias
) =>
HasStateDict (GLMHead inputEmbedDim dense activation layerNorm decoder bias)
instance HasStateDict bias => HasStateDict (GBias bias)
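-- | Specifies the output tensor type of the language modelling head, given
-- the transformer @style@ and the type of the decoder output.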
type family
LMHeadOutputF
(style :: TransformerStyle)
(decoderOutput :: Type)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(vocabDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadOutputF 'T5 decoderOutput _ _ _ _ = decoderOutput
LMHeadOutputF 'ByT5 decoderOutput gradient device dataType vocabDim = LMHeadOutputF 'T5 decoderOutput gradient device dataType vocabDim
LMHeadOutputF 'BART (Tensor gradient' layout' device' dataType' shape') gradient device dataType vocabDim =
Tensor
(gradient' <|> gradient)
(layout' <+> 'Layout 'Dense)
(device' <+> device)
(dataType' <+> dataType)
(BroadcastShapesF shape' ('Shape '[ 'Dim ('Name "*") ('Size 1), vocabDim]))
LMHeadOutputF 'MBART decoderOutput gradient device dataType vocabDim = LMHeadOutputF 'BART decoderOutput gradient device dataType vocabDim
LMHeadOutputF 'Pegasus decoderOutput gradient device dataType vocabDim = LMHeadOutputF 'BART decoderOutput gradient device dataType vocabDim
LMHeadOutputF 'RoBERTa decoderOutput _ _ _ _ = decoderOutput
LMHeadOutputF 'BERT decoderOutput _ _ _ _ = decoderOutput
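-- The forward pass of 'GLMHead' chains the dense layer, the activation, the
-- layer norm, and the decoder, then (for styles with 'LMHeadWithScaling')
-- scales the result by @1 / sqrt inputEmbedDim@, and finally applies the bias.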
instance
( HasForward
dense
(Tensor gradient layout device dataType shape)
generatorDevice
tensor0
generatorDevice0,
HasForward
activation
tensor0
generatorDevice0
tensor1
generatorDevice1,
HasForward
layerNorm
tensor1
generatorDevice1
tensor2
generatorDevice2,
HasForward
decoder
tensor2
generatorDevice2
(Tensor gradient3 layout3 device3 dataType3 shape3)
generatorDevice3,
HasForward
bias
(Tensor gradient3 layout3 device3 dataType3 shape3)
generatorDevice3
output
generatorOutputDevice
) =>
HasForward
(GLMHead inputEmbedDim dense activation layerNorm decoder bias)
(Tensor gradient layout device dataType shape)
generatorDevice
output
generatorOutputDevice
where
  forward GLMHead {..} input =
    let scaling :: Double
        scaling = 1 / (sqrt . fromIntegral . forgetIsChecked . dimSize . fromSing $ lmHeadInputEmbedDim)
     in runIxStateT $
          ireturn input
            >>>= IxStateT . forward lmHeadDense
            >>>= IxStateT . forward lmHeadActivation
            >>>= IxStateT . forward lmHeadLayerNorm
            >>>= IxStateT . forward lmHeadDecoder
            >>>= ilift
              . ( \case
                    LMHeadWithoutScaling -> pure
                    LMHeadWithScaling -> flip mulScalar scaling
                )
                lmHeadHasScaling
            >>>= IxStateT . forward lmHeadBias
instance
HasForward
(GBias ())
(Tensor gradient layout device dataType shape)
generatorDevice
(Tensor gradient layout device dataType shape)
generatorDevice
where
  forward (GBias bias) = forward bias
instance
( shape' ~ BroadcastShapesF shape biasShape,
Catch shape',
output
~ Tensor
(gradient <|> biasGradient)
(layout <+> biasLayout)
(device <+> biasDevice)
(dataType <+> biasDataType)
shape'
) =>
HasForward
(GBias (Tensor biasGradient biasLayout biasDevice biasDataType biasShape))
(Tensor gradient layout device dataType shape)
generatorDevice
output
generatorDevice
where
  forward (GBias bias) input g = do
    r <- input `add` bias
    pure (r, g)
instance
( shape' ~ BroadcastShapesF shape biasShape,
Catch shape',
output
~ Tensor
(gradient <|> biasGradient)
(layout <+> biasLayout)
(device <+> biasDevice)
(dataType <+> biasDataType)
shape'
) =>
HasForward
(GBias (NamedModel (Tensor biasGradient biasLayout biasDevice biasDataType biasShape)))
(Tensor gradient layout device dataType shape)
generatorDevice
output
generatorDevice
where
  forward (GBias (NamedModel _ bias)) input g = do
    r <- input `add` bias
    pure (r, g)