{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin TypeLevel.Rewrite
                -fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyIdempotenceL2
                -fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.OrIdempotenceL2 #-}

module Torch.GraduallyTyped.NN.Normalization where

import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (Layout), LayoutType (Dense), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.NN.Functional.Normalization (LayerNormWithBiasF, LayerNormWithoutBiasF, layerNormWithBias, layerNormWithoutBias)
import Torch.GraduallyTyped.NN.Type (HasBias (..), SHasBias (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient)
import Torch.GraduallyTyped.Shape (Dim (..), Name (..), SShape (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Creation (sOnes, sZeros)
import Torch.GraduallyTyped.Tensor.Type (SGetShape, Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

-- | Layer normalization model.
--
-- The @hasBias@ index selects the constructor: 'LayerNormWithBias' carries a
-- weight tensor, a bias tensor and a 'Double'; 'LayerNormWithoutBias' carries
-- only a weight tensor and a 'Double'. All stored tensors use the dense layout
-- and share the model's @gradient@, @device@, @dataType@ and
-- @normalizedShape@ indices.
data
  LayerNorm
    (hasBias :: HasBias)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (normalizedShape :: Shape [Dim (Name Symbol) (Size Nat)])
  where
  -- Variant with both a weight and a bias tensor.
  LayerNormWithBias ::
    forall gradient device dataType normalizedShape.
    { -- Weight tensor of shape @normalizedShape@.
      layerNormWithBiasWeight :: Tensor gradient ('Layout 'Dense) device dataType normalizedShape,
      -- Bias tensor of shape @normalizedShape@.
      layerNormBias :: Tensor gradient ('Layout 'Dense) device dataType normalizedShape,
      -- NOTE(review): presumably the epsilon added for numerical stability;
      -- confirm against 'layerNormWithBias'.
      layerNormWithBiasEps :: Double
    } ->
    LayerNorm 'WithBias gradient device dataType normalizedShape
  -- Variant with a weight tensor only.
  LayerNormWithoutBias ::
    forall gradient device dataType normalizedShape.
    { -- Weight tensor of shape @normalizedShape@.
      layerNormWithoutBiasWeight :: Tensor gradient ('Layout 'Dense) device dataType normalizedShape,
      -- NOTE(review): presumably the epsilon added for numerical stability;
      -- confirm against 'layerNormWithoutBias'.
      layerNormWithoutBiasEps :: Double
    } ->
    LayerNorm 'WithoutBias gradient device dataType normalizedShape

-- | Specification of a 'LayerNorm' model.
--
-- Holds singletons for the bias flag, gradient requirement, device, data type
-- and normalized shape, followed by the 'Double' that is copied into the
-- model's eps field when the model is created.
data
  LayerNormSpec
    (hasBias :: HasBias)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (normalizedShape :: Shape [Dim (Name Symbol) (Size Nat)])
  where
  LayerNormSpec ::
    forall hasBias gradient device dataType normalizedShape.
    SHasBias hasBias ->
    SGradient gradient ->
    SDevice device ->
    SDataType dataType ->
    SShape normalizedShape ->
    Double ->
    LayerNormSpec hasBias gradient device dataType normalizedShape
  deriving stock (Show, Generic)

-- | A 'LayerNorm' is specified by the 'LayerNormSpec' with the same indices.
type instance
  ModelSpec (LayerNorm hasBias gradient device dataType normalizedShape) =
    LayerNormSpec hasBias gradient device dataType normalizedShape

-- | Creates a 'LayerNorm' from its 'LayerNormSpec'.
--
-- The weight is created with 'sOnes' and, for the 'SWithBias' case, the bias
-- with 'sZeros'. The eps value is copied from the spec. The generator is not
-- consumed: it is returned exactly as it was passed in.
instance
  HasInitialize
    (LayerNorm hasBias gradient device dataType normalizedShape)
    generatorDevice
    (LayerNorm hasBias gradient device dataType normalizedShape)
    generatorDevice
  where
  initialize (LayerNormSpec SWithBias gradient device dataType normalizedShape eps) g = do
    -- Weight and bias share one dense tensor spec.
    let tensorSpec = TensorSpec gradient (SLayout SDense) device dataType normalizedShape
    weight <- sOnes tensorSpec
    bias <- sZeros tensorSpec
    pure (LayerNormWithBias weight bias eps, g)
  initialize (LayerNormSpec SWithoutBias gradient device dataType normalizedShape eps) g = do
    let tensorSpec = TensorSpec gradient (SLayout SDense) device dataType normalizedShape
    weight <- sOnes tensorSpec
    pure (LayerNormWithoutBias weight eps, g)

-- | State-dict (de)serialization for 'LayerNorm'.
--
-- Tensors are stored under the given key with @"weight"@ (and, for the
-- with-bias variant, @"bias"@) appended directly via '<>'. The eps value is
-- never written to or read from the state dict; on loading it is taken from
-- the 'LayerNormSpec'.
instance
  HasStateDict
    (LayerNorm hasBias gradient device dataType normalizedShape)
  where
  fromStateDict (LayerNormSpec SWithBias gradient device dataType normalizedShape eps) k =
    LayerNormWithBias
      <$> fromStateDict (TensorSpec gradient (SLayout SDense) device dataType normalizedShape) (k <> "weight")
      <*> fromStateDict (TensorSpec gradient (SLayout SDense) device dataType normalizedShape) (k <> "bias")
      <*> pure eps
  fromStateDict (LayerNormSpec SWithoutBias gradient device dataType normalizedShape eps) k =
    LayerNormWithoutBias
      <$> fromStateDict (TensorSpec gradient (SLayout SDense) device dataType normalizedShape) (k <> "weight")
      <*> pure eps
  toStateDict k LayerNormWithBias {..} = do
    toStateDict (k <> "weight") layerNormWithBiasWeight
    toStateDict (k <> "bias") layerNormBias
  toStateDict k LayerNormWithoutBias {..} =
    toStateDict (k <> "weight") layerNormWithoutBiasWeight

-- | Forward pass of a 'LayerNorm' with bias.
--
-- Applies 'layerNormWithBias' to the stored weight, bias and eps together
-- with the input tensor. The generator is passed through unchanged.
instance
  ( SGetShape normalizedShape,
    output
      ~ Tensor
          (gradient <|> gradient')
          ('Layout 'Dense <+> layout')
          (device <+> device')
          (dataType <+> dataType')
          (LayerNormWithBiasF normalizedShape normalizedShape shape')
  ) =>
  HasForward
    (LayerNorm 'WithBias gradient device dataType normalizedShape)
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  -- The tuple section pairs the result with whatever generator is supplied.
  forward LayerNormWithBias {..} input =
    pure . (layerNormWithBias layerNormWithBiasWeight layerNormBias layerNormWithBiasEps input,)

-- | Forward pass of a 'LayerNorm' without bias.
--
-- Applies 'layerNormWithoutBias' to the stored weight and eps together with
-- the input tensor. The generator is passed through unchanged.
instance
  ( SGetShape normalizedShape,
    SGetShape shape',
    output
      ~ Tensor
          (gradient <|> gradient')
          ('Layout 'Dense <+> layout')
          (device <+> device')
          (dataType <+> dataType')
          (LayerNormWithoutBiasF normalizedShape shape')
  ) =>
  HasForward
    (LayerNorm 'WithoutBias gradient device dataType normalizedShape)
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  -- The tuple section pairs the result with whatever generator is supplied.
  forward LayerNormWithoutBias {..} input =
    pure . (layerNormWithoutBias layerNormWithoutBiasWeight layerNormWithoutBiasEps input,)