{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Typed layer normalization.
--
-- This module provides four pieces:
--
-- * a hyper-parameter spec ('LayerNormSpec');
-- * the layer itself ('LayerNorm'), with a learnable elementwise weight and
--   bias of shape @normalizedShape@;
-- * its forward pass ('layerNormForward');
-- * a 'Randomizable' instance that builds a layer from a spec.
module Torch.Typed.NN.Normalization where

import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Hyper-parameters for constructing a 'LayerNorm' via 'sample'.
--
-- The type indices fix the shape of the learnable parameters
-- (@normalizedShape@), their dtype and the device they are created on.
-- The only term-level setting is the epsilon.
data LayerNormSpec (normalizedShape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  LayerNormSpec ::
    forall normalizedShape dtype device.
    { -- | Epsilon that 'sample' copies into 'layerNormEps'.
      layerNormEpsSpec :: Double
    } ->
    LayerNormSpec normalizedShape dtype device
  deriving (Show, Eq)

-- | A layer-normalization layer with a learnable elementwise affine
-- transform over the trailing @normalizedShape@ dimensions of its input.
data LayerNorm (normalizedShape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  LayerNorm ::
    { -- | Elementwise scale passed to 'layerNorm'.
      layerNormWeight :: Parameter device dtype normalizedShape,
      -- | Elementwise shift passed to 'layerNorm'.
      layerNormBias :: Parameter device dtype normalizedShape,
      -- | Epsilon passed to 'layerNorm'. It is presumably added to the
      -- variance for numerical stability; see 'layerNorm' to confirm.
      layerNormEps :: Double
    } ->
    LayerNorm normalizedShape dtype device
  deriving (Show, Generic, Parameterized)

-- | Apply layer normalization to an input tensor.
--
-- This delegates to 'layerNorm', passing the layer's weight and bias
-- (unwrapped with 'toDependent') together with its epsilon.
--
-- The 'IsSuffixOf' constraint requires @normalizedShape@ to be a suffix of
-- the input @shape@. The result has the same shape, dtype and device as the
-- input.
layerNormForward ::
  forall normalizedShape shape dtype device.
  ( IsSuffixOf normalizedShape shape,
    KnownShape normalizedShape
  ) =>
  LayerNorm normalizedShape dtype device ->
  Tensor device dtype shape ->
  Tensor device dtype shape
layerNormForward LayerNorm {..} =
  -- The explicit type application pins which trailing dimensions are
  -- normalized, since 'layerNorm' cannot infer them from @shape@ alone.
  layerNorm @normalizedShape
    (toDependent layerNormWeight)
    (toDependent layerNormBias)
    layerNormEps

-- | 'forward' is 'layerNormForward'.
--
-- 'forwardStoch' has no extra randomness to add. It runs the same
-- deterministic 'forward' and returns the result in 'IO'.
instance
  ( IsSuffixOf normalizedShape shape,
    KnownShape normalizedShape
  ) =>
  HasForward (LayerNorm normalizedShape dtype device) (Tensor device dtype shape) (Tensor device dtype shape)
  where
  forward = layerNormForward
  forwardStoch = (pure .) . forward

-- | Build a 'LayerNorm' from its spec.
--
-- The affine parameters start at weight = 1 and bias = 0 rather than random
-- draws. This is the standard layer-norm initialization (PyTorch's
-- @torch.nn.LayerNorm@ does the same). Assuming 'layerNorm' applies weight
-- and bias as an elementwise affine step, a freshly sampled layer performs
-- plain normalization instead of randomly rescaling, sign-flipping and
-- shifting each feature.
--
-- The 'RandDTypeIsValid' constraint is no longer needed by the body. It is
-- kept so the instance context stays unchanged for existing users.
instance
  ( TensorOptions normalizedShape dtype device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable (LayerNormSpec normalizedShape dtype device) (LayerNorm normalizedShape dtype device)
  where
  sample LayerNormSpec {..} =
    LayerNorm
      <$> makeIndependent ones
      <*> makeIndependent zeros
      <*> pure layerNormEpsSpec