{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin TypeLevel.Rewrite -fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyIdempotenceL2 -fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.OrIdempotenceL2 #-}

-- | Layer normalization model, its specification, and the instances needed to
-- initialize it, (de)serialize it from a state dictionary, and run it forward.
module Torch.GraduallyTyped.NN.Normalization where

import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (Layout), LayoutType (Dense), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.NN.Functional.Normalization (LayerNormWithBiasF, LayerNormWithoutBiasF, layerNormWithBias, layerNormWithoutBias)
import Torch.GraduallyTyped.NN.Type (HasBias (..), SHasBias (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient)
import Torch.GraduallyTyped.Shape (Dim (..), Name (..), SShape (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Creation (sOnes, sZeros)
import Torch.GraduallyTyped.Tensor.Type (SGetShape, Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

-- | Layer normalization parameters.
--
-- The @hasBias@ index selects the constructor: 'LayerNormWithBias' carries a
-- weight and a bias tensor, 'LayerNormWithoutBias' carries only a weight.
-- All parameter tensors are dense and share the same gradient, device,
-- data type, and @normalizedShape@.
data
  LayerNorm
    (hasBias :: HasBias)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (normalizedShape :: Shape [Dim (Name Symbol) (Size Nat)])
  where
  -- Variant with both a weight and a bias tensor.
  LayerNormWithBias ::
    forall gradient device dataType normalizedShape.
    { -- Weight tensor of shape @normalizedShape@.
      layerNormWithBiasWeight :: Tensor gradient ('Layout 'Dense) device dataType normalizedShape,
      -- Bias tensor of shape @normalizedShape@.
      layerNormBias :: Tensor gradient ('Layout 'Dense) device dataType normalizedShape,
      -- Epsilon value handed to 'layerNormWithBias' in the forward pass.
      layerNormWithBiasEps :: Double
    } ->
    LayerNorm 'WithBias gradient device dataType normalizedShape
  -- Variant with a weight tensor only.
  LayerNormWithoutBias ::
    forall gradient device dataType normalizedShape.
    { -- Weight tensor of shape @normalizedShape@.
      layerNormWithoutBiasWeight :: Tensor gradient ('Layout 'Dense) device dataType normalizedShape,
      -- Epsilon value handed to 'layerNormWithoutBias' in the forward pass.
      layerNormWithoutBiasEps :: Double
    } ->
    LayerNorm 'WithoutBias gradient device dataType normalizedShape

-- | Specification of a 'LayerNorm' model.
--
-- Holds, in order: the bias singleton, the gradient, device, data type, and
-- normalized-shape singletons of the parameter tensors, and the epsilon value.
data
  LayerNormSpec
    (hasBias :: HasBias)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (normalizedShape :: Shape [Dim (Name Symbol) (Size Nat)])
  where
  LayerNormSpec ::
    forall hasBias gradient device dataType normalizedShape.
    SHasBias hasBias ->
    SGradient gradient ->
    SDevice device ->
    SDataType dataType ->
    SShape normalizedShape ->
    Double ->
    LayerNormSpec hasBias gradient device dataType normalizedShape
  deriving stock (Show, Generic)

-- | A 'LayerNorm' is specified by a 'LayerNormSpec' with the same indices.
type instance
  ModelSpec (LayerNorm hasBias gradient device dataType normalizedShape) =
    LayerNormSpec hasBias gradient device dataType normalizedShape

-- | Initialization builds the weight with 'sOnes' and, for the 'WithBias'
-- variant, the bias with 'sZeros'. The generator is returned unchanged.
instance
  HasInitialize
    (LayerNorm hasBias gradient device dataType normalizedShape)
    generatorDevice
    (LayerNorm hasBias gradient device dataType normalizedShape)
    generatorDevice
  where
  initialize (LayerNormSpec SWithBias gradient device dataType normalizedShape eps) g = do
    -- One tensor spec is shared by the weight and the bias.
    let tensorSpec = TensorSpec gradient (SLayout SDense) device dataType normalizedShape
    weight <- sOnes tensorSpec
    bias <- sZeros tensorSpec
    pure (LayerNormWithBias weight bias eps, g)
  initialize (LayerNormSpec SWithoutBias gradient device dataType normalizedShape eps) g = do
    let tensorSpec = TensorSpec gradient (SLayout SDense) device dataType normalizedShape
    weight <- sOnes tensorSpec
    pure (LayerNormWithoutBias weight eps, g)

-- | State-dictionary (de)serialization.
--
-- The weight is stored under the key @k <> "weight"@ and, for the 'WithBias'
-- variant, the bias under @k <> "bias"@. The epsilon value is never read from
-- or written to the state dictionary; when loading it comes from the spec.
instance HasStateDict (LayerNorm hasBias gradient device dataType normalizedShape) where
  fromStateDict (LayerNormSpec SWithBias gradient device dataType normalizedShape eps) k =
    LayerNormWithBias
      <$> fromStateDict (TensorSpec gradient (SLayout SDense) device dataType normalizedShape) (k <> "weight")
      <*> fromStateDict (TensorSpec gradient (SLayout SDense) device dataType normalizedShape) (k <> "bias")
      <*> pure eps
  fromStateDict (LayerNormSpec SWithoutBias gradient device dataType normalizedShape eps) k =
    LayerNormWithoutBias
      <$> fromStateDict (TensorSpec gradient (SLayout SDense) device dataType normalizedShape) (k <> "weight")
      <*> pure eps
  toStateDict k LayerNormWithBias {..} = do
    toStateDict (k <> "weight") layerNormWithBiasWeight
    toStateDict (k <> "bias") layerNormBias
  toStateDict k LayerNormWithoutBias {..} =
    toStateDict (k <> "weight") layerNormWithoutBiasWeight

-- | Forward pass of the 'WithBias' variant.
--
-- Applies 'layerNormWithBias' to the stored weight, bias, and epsilon together
-- with the input tensor, and pairs the result with the unchanged generator.
instance
  ( SGetShape normalizedShape,
    output
      ~ Tensor
          (gradient <|> gradient')
          ('Layout 'Dense <+> layout')
          (device <+> device')
          (dataType <+> dataType')
          (LayerNormWithBiasF normalizedShape normalizedShape shape')
  ) =>
  HasForward
    (LayerNorm 'WithBias gradient device dataType normalizedShape)
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  forward LayerNormWithBias {..} input =
    -- The tuple section leaves the generator slot open, so the result is a
    -- function from the generator to the (output, generator) pair.
    pure . (layerNormWithBias layerNormWithBiasWeight layerNormBias layerNormWithBiasEps input,)

-- | Forward pass of the 'WithoutBias' variant.
--
-- Applies 'layerNormWithoutBias' to the stored weight and epsilon together
-- with the input tensor, and pairs the result with the unchanged generator.
instance
  ( SGetShape normalizedShape,
    SGetShape shape',
    output
      ~ Tensor
          (gradient <|> gradient')
          ('Layout 'Dense <+> layout')
          (device <+> device')
          (dataType <+> dataType')
          (LayerNormWithoutBiasF normalizedShape shape')
  ) =>
  HasForward
    (LayerNorm 'WithoutBias gradient device dataType normalizedShape)
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  forward LayerNormWithoutBias {..} input =
    pure . (layerNormWithoutBias layerNormWithoutBiasWeight layerNormWithoutBiasEps input,)