{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Transformer.GLMHead where
import Control.Monad.Indexed (IxPointed (..), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>))
import Data.Kind (Type)
import Data.Singletons (SingKind (fromSing))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType, DataType, SDataType (..))
import Torch.GraduallyTyped.Device (Device, DeviceType, SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Activation (Gelu (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Linear (GLinearF, linearSpec)
import Torch.GraduallyTyped.NN.Normalization (LayerNorm (..), LayerNormSpec (..))
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle (..), TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..), SHasBias (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim, SName (..), SShape (..), SSize (..), Shape (..), Size (..), pattern (:&:))
import Torch.GraduallyTyped.Tensor.Creation (sZeros)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (add, mulScalar)
import Torch.GraduallyTyped.Tensor.Type (Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
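-- | Whether the language modelling head scales the decoder output by
-- @1 / sqrt inputEmbedDim@ before the bias is applied.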
data LMHeadHasScaling
  = LMHeadWithScaling
  | LMHeadWithoutScaling
  deriving stock (Eq, Ord, Show, Generic)
type instance ModelSpec LMHeadHasScaling = LMHeadHasScaling
instance HasInitialize LMHeadHasScaling generatorDevice LMHeadHasScaling generatorDevice where
  initialize hasScaling g = pure (hasScaling, g)
instance HasStateDict LMHeadHasScaling where
  fromStateDict hasScaling _ = pure hasScaling
  toStateDict _ _ = pure ()
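-- | Generic language modelling head for transformer encoders and decoders.
--
-- - @inputEmbedDim@: the dimension of the input embeddings,
-- - @dense@: a dense layer (or @()@ for styles without one),
-- - @activation@: an activation function (or @()@),
-- - @layerNorm@: a layer normalization layer (or @()@),
-- - @decoder@: the projection onto the vocabulary,
-- - @bias@: a bias wrapped in 'GBias'.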
data
  GLMHead
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat))
    (dense :: Type)
    (activation :: Type)
    (layerNorm :: Type)
    (decoder :: Type)
    (bias :: Type)
  where
  GLMHead ::
    forall inputEmbedDim dense activation layerNorm decoder bias.
    { lmHeadInputEmbedDim :: SDim inputEmbedDim,
      lmHeadDense :: dense,
      lmHeadActivation :: activation,
      lmHeadLayerNorm :: layerNorm,
      lmHeadDecoder :: decoder,
      lmHeadBias :: bias,
      lmHeadHasScaling :: LMHeadHasScaling
    } ->
    GLMHead inputEmbedDim dense activation layerNorm decoder bias
  deriving stock (Show, Generic)
type instance
ModelSpec (GLMHead inputEmbedDim dense activation layerNorm decoder bias) =
GLMHead inputEmbedDim (ModelSpec dense) (ModelSpec activation) (ModelSpec layerNorm) (ModelSpec decoder) (ModelSpec bias)
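-- | Wrapper for the bias of the language modelling head. Depending on the
-- transformer style, the bias is absent (@GBias ()@) or a (possibly named)
-- tensor.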
data GBias (bias :: Type) where
  GBias :: forall bias. bias -> GBias bias
  deriving stock (Eq, Ord, Show, Generic)
type instance ModelSpec (GBias bias) = GBias (ModelSpec bias)
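-- | Specifies the language modelling head of the transformer for a given
-- @style@.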
type family
GLMHeadF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(vocabDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
GLMHeadF style gradient device dataType inputEmbedDim vocabDim =
GLMHead
inputEmbedDim
(LMHeadDenseF style gradient device dataType inputEmbedDim)
(LMHeadActivationF style)
(LMHeadLayerNormF style gradient device dataType inputEmbedDim)
(LMHeadDecoderF style gradient device dataType inputEmbedDim vocabDim)
(LMHeadBiasF style gradient device dataType vocabDim)
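-- | Specifies the dense layer of the language modelling head.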
type family
LMHeadDenseF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadDenseF 'T5 _ _ _ _ =
()
LMHeadDenseF 'ByT5 gradient device dataType inputEmbedDim =
LMHeadDenseF 'T5 gradient device dataType inputEmbedDim
LMHeadDenseF 'BART _ _ _ _ =
()
LMHeadDenseF 'MBART gradient device dataType inputEmbedDim =
LMHeadDenseF 'BART gradient device dataType inputEmbedDim
LMHeadDenseF 'Pegasus gradient device dataType inputEmbedDim =
LMHeadDenseF 'BART gradient device dataType inputEmbedDim
LMHeadDenseF 'BERT gradient device dataType inputEmbedDim =
NamedModel (GLinearF 'WithBias gradient device dataType inputEmbedDim inputEmbedDim)
LMHeadDenseF 'RoBERTa gradient device dataType inputEmbedDim =
LMHeadDenseF 'BERT gradient device dataType inputEmbedDim
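-- | Specifies the activation function of the language modelling head.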
type family
LMHeadActivationF
(style :: TransformerStyle) ::
Type
where
LMHeadActivationF 'T5 = ()
LMHeadActivationF 'ByT5 = LMHeadActivationF 'T5
LMHeadActivationF 'BART = ()
LMHeadActivationF 'MBART = LMHeadActivationF 'BART
LMHeadActivationF 'Pegasus = LMHeadActivationF 'BART
LMHeadActivationF 'BERT = Gelu
LMHeadActivationF 'RoBERTa = LMHeadActivationF 'BERT
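-- | Specifies the layer normalization layer of the language modelling head.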
type family
LMHeadLayerNormF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadLayerNormF 'T5 _ _ _ _ = ()
LMHeadLayerNormF 'ByT5 gradient device dataType inputEmbedDim =
LMHeadLayerNormF 'T5 gradient device dataType inputEmbedDim
LMHeadLayerNormF 'BART _ _ _ _ = ()
LMHeadLayerNormF 'MBART gradient device dataType inputEmbedDim =
LMHeadLayerNormF 'BART gradient device dataType inputEmbedDim
LMHeadLayerNormF 'Pegasus gradient device dataType inputEmbedDim =
LMHeadLayerNormF 'BART gradient device dataType inputEmbedDim
LMHeadLayerNormF 'BERT gradient device dataType inputEmbedDim =
NamedModel (LayerNorm 'WithBias gradient device dataType ('Shape '[inputEmbedDim]))
LMHeadLayerNormF 'RoBERTa gradient device dataType inputEmbedDim =
LMHeadLayerNormF 'BERT gradient device dataType inputEmbedDim
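-- | Specifies the decoder layer of the language modelling head.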
type family
LMHeadDecoderF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(inputEmbedDim :: Dim (Name Symbol) (Size Nat))
(vocabDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadDecoderF 'T5 gradient device dataType inputEmbedDim vocabDim =
NamedModel (GLinearF 'WithoutBias gradient device dataType inputEmbedDim vocabDim)
LMHeadDecoderF 'ByT5 gradient device dataType inputEmbedDim vocabDim =
LMHeadDecoderF 'T5 gradient device dataType inputEmbedDim vocabDim
LMHeadDecoderF 'BART gradient device dataType inputEmbedDim vocabDim =
NamedModel (GLinearF 'WithoutBias gradient device dataType inputEmbedDim vocabDim)
LMHeadDecoderF 'MBART gradient device dataType inputEmbedDim vocabDim =
LMHeadDecoderF 'BART gradient device dataType inputEmbedDim vocabDim
LMHeadDecoderF 'Pegasus gradient device dataType inputEmbedDim vocabDim =
LMHeadDecoderF 'BART gradient device dataType inputEmbedDim vocabDim
LMHeadDecoderF 'BERT gradient device dataType inputEmbedDim vocabDim =
NamedModel (GLinearF 'WithBias gradient device dataType inputEmbedDim vocabDim)
LMHeadDecoderF 'RoBERTa gradient device dataType inputEmbedDim vocabDim =
LMHeadDecoderF 'BERT gradient device dataType inputEmbedDim vocabDim
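-- | Specifies the bias of the language modelling head.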
type family
LMHeadBiasF
(style :: TransformerStyle)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(vocabDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadBiasF 'T5 _ _ _ _ =
GBias ()
LMHeadBiasF 'ByT5 gradient device dataType vocabDim =
LMHeadBiasF 'T5 gradient device dataType vocabDim
LMHeadBiasF 'BART gradient device dataType vocabDim =
GBias (NamedModel (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[ 'Dim ('Name "*") ('Size 1), vocabDim])))
LMHeadBiasF 'MBART gradient device dataType vocabDim =
LMHeadBiasF 'BART gradient device dataType vocabDim
LMHeadBiasF 'Pegasus gradient device dataType vocabDim =
LMHeadBiasF 'BART gradient device dataType vocabDim
LMHeadBiasF 'BERT _ _ _ _ =
GBias ()
LMHeadBiasF 'RoBERTa gradient device dataType vocabDim =
LMHeadBiasF 'BERT gradient device dataType vocabDim
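-- | Specifies the parameters of the language modelling head for a given
-- transformer @style@, @gradient@, @device@, @dataType@, input embedding
-- dimension, vocabulary dimension, and layer-norm epsilon.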
lmHeadSpec ::
forall style gradient device dataType inputEmbedDim vocabDim.
STransformerStyle style ->
SGradient gradient ->
SDevice device ->
SDataType dataType ->
SDim inputEmbedDim ->
SDim vocabDim ->
Double ->
ModelSpec (GLMHeadF style gradient device dataType inputEmbedDim vocabDim)
lmHeadSpec style gradient device dataType inputEmbedDim vocabDim eps =
  let denseSpec :: STransformerStyle style -> ModelSpec (LMHeadDenseF style gradient device dataType inputEmbedDim)
      denseSpec ST5 = ()
      denseSpec SByT5 = ()
      denseSpec SBART = ()
      denseSpec SMBART = ()
      denseSpec SPegasus = ()
      denseSpec SBERT = NamedModel "transform.dense." linearSpec'
      denseSpec SRoBERTa = NamedModel "dense." linearSpec'
      denseSpec SGPT2 = undefined
      activationSpec :: STransformerStyle style -> ModelSpec (LMHeadActivationF style)
      activationSpec ST5 = ()
      activationSpec SByT5 = ()
      activationSpec SBART = ()
      activationSpec SMBART = ()
      activationSpec SPegasus = ()
      activationSpec SBERT = Gelu
      activationSpec SRoBERTa = Gelu
      activationSpec SGPT2 = undefined
      layerNormSpec :: STransformerStyle style -> ModelSpec (LMHeadLayerNormF style gradient device dataType inputEmbedDim)
      layerNormSpec ST5 = ()
      layerNormSpec SByT5 = ()
      layerNormSpec SBART = ()
      layerNormSpec SMBART = ()
      layerNormSpec SPegasus = ()
      layerNormSpec SBERT = NamedModel "transform.LayerNorm." layerNormSpec'
      layerNormSpec SRoBERTa = NamedModel "layer_norm." layerNormSpec'
      layerNormSpec SGPT2 = undefined
      decoderSpec :: STransformerStyle style -> ModelSpec (LMHeadDecoderF style gradient device dataType inputEmbedDim vocabDim)
      decoderSpec ST5 = NamedModel mempty linearWithoutBiasSpec'
      decoderSpec SByT5 = NamedModel mempty linearWithoutBiasSpec'
      decoderSpec SBART = NamedModel "lm_head." linearWithoutBiasSpec'
      decoderSpec SMBART = NamedModel "lm_head." linearWithoutBiasSpec'
      decoderSpec SPegasus = NamedModel "lm_head." linearWithoutBiasSpec'
      decoderSpec SBERT = NamedModel "decoder." linearWithBiasSpec'
      decoderSpec SRoBERTa = NamedModel "decoder." linearWithBiasSpec'
      decoderSpec SGPT2 = undefined
      biasSpec :: STransformerStyle style -> ModelSpec (LMHeadBiasF style gradient device dataType vocabDim)
      biasSpec ST5 = GBias ()
      biasSpec SByT5 = GBias ()
      biasSpec SBART = GBias (NamedModel "final_logits_bias" biasSpec')
      biasSpec SMBART = GBias (NamedModel "final_logits_bias" biasSpec')
      biasSpec SPegasus = GBias (NamedModel "final_logits_bias" biasSpec')
      biasSpec SBERT = GBias ()
      biasSpec SRoBERTa = GBias ()
      biasSpec SGPT2 = undefined
      scalingSpec :: STransformerStyle style -> LMHeadHasScaling
      scalingSpec ST5 = LMHeadWithScaling
      scalingSpec SByT5 = LMHeadWithScaling
      scalingSpec SBART = LMHeadWithoutScaling
      scalingSpec SMBART = LMHeadWithoutScaling
      scalingSpec SPegasus = LMHeadWithoutScaling
      scalingSpec SBERT = LMHeadWithoutScaling
      scalingSpec SRoBERTa = LMHeadWithoutScaling
      scalingSpec SGPT2 = undefined
   in GLMHead inputEmbedDim (denseSpec style) (activationSpec style) (layerNormSpec style) (decoderSpec style) (biasSpec style) (scalingSpec style)
  where
    linearSpec' = linearSpec SWithBias gradient device dataType inputEmbedDim inputEmbedDim
    biasSpec' = TensorSpec gradient (SLayout SDense) device dataType (SShape $ SName @"*" :&: SSize @1 :|: vocabDim :|: SNil)
    layerNormSpec' = LayerNormSpec SWithBias gradient device dataType (SShape $ inputEmbedDim :|: SNil) eps
    linearWithoutBiasSpec' = linearSpec SWithoutBias gradient device dataType inputEmbedDim vocabDim
    linearWithBiasSpec' = linearSpec SWithBias gradient device dataType inputEmbedDim vocabDim
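-- A minimal usage sketch for 'lmHeadSpec'. The concrete singleton values below
-- ('SGradient SWithGradient', 'SDevice SCPU', 'SDataType SFloat') and the
-- dimension sizes are illustrative assumptions, not values defined in this
-- module; substitute your model's configuration:
--
-- > t5LMHeadSpec =
-- >   lmHeadSpec
-- >     ST5
-- >     (SGradient SWithGradient)
-- >     (SDevice SCPU)
-- >     (SDataType SFloat)
-- >     (SName @"*" :&: SSize @512)    -- input embedding dimension
-- >     (SName @"*" :&: SSize @32128)  -- vocabulary dimension
-- >     1e-6                           -- layer-norm epsilon
--
-- The resulting spec can then be passed to 'initialize' or 'fromStateDict'
-- to obtain an actual 'GLMHead' model.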
instance
( HasInitialize dense generatorDevice dense' generatorDevice0,
HasInitialize activation generatorDevice0 activation' generatorDevice1,
HasInitialize layerNorm generatorDevice1 layerNorm' generatorDevice2,
HasInitialize decoder generatorDevice2 decoder' generatorDevice3,
HasInitialize bias generatorDevice3 bias' generatorOutputDevice
) =>
HasInitialize
(GLMHead inputEmbedDim dense activation layerNorm decoder bias)
generatorDevice
(GLMHead inputEmbedDim dense' activation' layerNorm' decoder' bias')
generatorOutputDevice
instance HasInitialize (GBias ()) generatorDevice (GBias ()) generatorDevice where
  initialize (GBias ()) g = pure (GBias (), g)
instance
HasInitialize
(GBias (Tensor biasGradient biasLayout biasDevice biasDataType biasShape))
generatorDevice
(GBias (Tensor biasGradient biasLayout biasDevice biasDataType biasShape))
generatorDevice
where
  initialize (GBias biasSpec) =
    runIxStateT (GBias <<$>> (ilift . sZeros $ biasSpec))
instance
HasInitialize
(GBias (NamedModel (Tensor biasGradient biasLayout biasDevice biasDataType biasShape)))
generatorDevice
(GBias (NamedModel (Tensor biasGradient biasLayout biasDevice biasDataType biasShape)))
generatorDevice
where
  initialize (GBias (NamedModel biasName biasSpec)) =
    runIxStateT (GBias <<$>> ilift (NamedModel biasName <$> sZeros biasSpec))
instance
( HasStateDict dense,
HasStateDict activation,
HasStateDict layerNorm,
HasStateDict decoder,
HasStateDict bias
) =>
HasStateDict (GLMHead inputEmbedDim dense activation layerNorm decoder bias)
instance HasStateDict bias => HasStateDict (GBias bias)
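-- | Specifies the output tensor type of the language modelling head, given
-- the transformer @style@ and the type of the decoder output.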
type family
LMHeadOutputF
(style :: TransformerStyle)
(decoderOutput :: Type)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat))
(dataType :: DataType DType)
(vocabDim :: Dim (Name Symbol) (Size Nat)) ::
Type
where
LMHeadOutputF 'T5 decoderOutput _ _ _ _ = decoderOutput
LMHeadOutputF 'ByT5 decoderOutput gradient device dataType vocabDim = LMHeadOutputF 'T5 decoderOutput gradient device dataType vocabDim
LMHeadOutputF 'BART (Tensor gradient' layout' device' dataType' shape') gradient device dataType vocabDim =
Tensor
(gradient' <|> gradient)
(layout' <+> 'Layout 'Dense)
(device' <+> device)
(dataType' <+> dataType)
(BroadcastShapesF shape' ('Shape '[ 'Dim ('Name "*") ('Size 1), vocabDim]))
LMHeadOutputF 'MBART decoderOutput gradient device dataType vocabDim = LMHeadOutputF 'BART decoderOutput gradient device dataType vocabDim
LMHeadOutputF 'Pegasus decoderOutput gradient device dataType vocabDim = LMHeadOutputF 'BART decoderOutput gradient device dataType vocabDim
LMHeadOutputF 'RoBERTa decoderOutput _ _ _ _ = decoderOutput
LMHeadOutputF 'BERT decoderOutput _ _ _ _ = decoderOutput
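-- The forward pass of 'GLMHead' chains the dense layer, the activation, the
-- layer norm, and the decoder, then (for styles with 'LMHeadWithScaling')
-- scales the result by @1 / sqrt inputEmbedDim@, and finally applies the bias.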
instance
( HasForward
dense
(Tensor gradient layout device dataType shape)
generatorDevice
tensor0
generatorDevice0,
HasForward
activation
tensor0
generatorDevice0
tensor1
generatorDevice1,
HasForward
layerNorm
tensor1
generatorDevice1
tensor2
generatorDevice2,
HasForward
decoder
tensor2
generatorDevice2
(Tensor gradient3 layout3 device3 dataType3 shape3)
generatorDevice3,
HasForward
bias
(Tensor gradient3 layout3 device3 dataType3 shape3)
generatorDevice3
output
generatorOutputDevice
) =>
HasForward
(GLMHead inputEmbedDim dense activation layerNorm decoder bias)
(Tensor gradient layout device dataType shape)
generatorDevice
output
generatorOutputDevice
where
  forward GLMHead {..} input =
    let scaling :: Double
        scaling = 1 / (sqrt . fromIntegral . forgetIsChecked . dimSize . fromSing $ lmHeadInputEmbedDim)
     in runIxStateT $
          ireturn input
            >>>= IxStateT . forward lmHeadDense
            >>>= IxStateT . forward lmHeadActivation
            >>>= IxStateT . forward lmHeadLayerNorm
            >>>= IxStateT . forward lmHeadDecoder
            >>>= ilift
              . ( \case
                    LMHeadWithoutScaling -> pure
                    LMHeadWithScaling -> flip mulScalar scaling
                )
                lmHeadHasScaling
            >>>= IxStateT . forward lmHeadBias
instance
HasForward
(GBias ())
(Tensor gradient layout device dataType shape)
generatorDevice
(Tensor gradient layout device dataType shape)
generatorDevice
where
  forward (GBias bias) = forward bias
instance
( shape' ~ BroadcastShapesF shape biasShape,
Catch shape',
output
~ Tensor
(gradient <|> biasGradient)
(layout <+> biasLayout)
(device <+> biasDevice)
(dataType <+> biasDataType)
shape'
) =>
HasForward
(GBias (Tensor biasGradient biasLayout biasDevice biasDataType biasShape))
(Tensor gradient layout device dataType shape)
generatorDevice
output
generatorDevice
where
  forward (GBias bias) input g = do
    r <- input `add` bias
    pure (r, g)
instance
( shape' ~ BroadcastShapesF shape biasShape,
Catch shape',
output
~ Tensor
(gradient <|> biasGradient)
(layout <+> biasLayout)
(device <+> biasDevice)
(dataType <+> biasDataType)
shape'
) =>
HasForward
(GBias (NamedModel (Tensor biasGradient biasLayout biasDevice biasDataType biasShape)))
(Tensor gradient layout device dataType shape)
generatorDevice
output
generatorDevice
where
  forward (GBias (NamedModel _ bias)) input g = do
    r <- input `add` bias
    pure (r, g)