{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Typed.NN.Normalization where

import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Hyper-parameters used to sample a 'LayerNorm' layer through the
-- 'Randomizable' instance below.
--
-- The type indices fix the normalised (trailing) shape, the element dtype
-- and the device of the layer that 'sample' will produce; only the
-- numerical-stability epsilon is carried at the value level.
data LayerNormSpec (normalizedShape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  LayerNormSpec ::
    forall normalizedShape dtype device.
    { -- | Epsilon copied verbatim into 'layerNormEps' of the sampled layer.
      layerNormEpsSpec :: Double
    } ->
    LayerNormSpec normalizedShape dtype device
  deriving (Show, Eq)

-- | A layer-normalisation layer with learnable element-wise affine
-- parameters whose shape is exactly @normalizedShape@.
--
-- 'Parameterized' is derived through 'Generic' (via @DeriveAnyClass@), so
-- the two 'Parameter' fields are the layer's trainable parameters.
data LayerNorm (normalizedShape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  LayerNorm ::
    { -- | Multiplicative (gain) parameter, shape @normalizedShape@.
      layerNormWeight :: Parameter device dtype normalizedShape,
      -- | Additive (bias) parameter, shape @normalizedShape@.
      layerNormBias :: Parameter device dtype normalizedShape,
      -- | Epsilon passed to 'layerNorm' for numerical stability.
      layerNormEps :: Double
    } ->
    LayerNorm normalizedShape dtype device
  deriving (Show, Generic, Parameterized)

-- | Apply a 'LayerNorm' layer to an input tensor.
--
-- The normalisation runs over the trailing @normalizedShape@ dimensions of
-- the input (enforced at the type level by 'IsSuffixOf'). The layer's
-- weight and bias parameters are unwrapped with 'toDependent' and forwarded
-- together with its epsilon to 'layerNorm'. The output has the same shape,
-- dtype and device as the input.
layerNormForward ::
  forall normalizedShape shape dtype device.
  ( IsSuffixOf normalizedShape shape,
    KnownShape normalizedShape
  ) =>
  -- | the layer (weight, bias, epsilon)
  LayerNorm normalizedShape dtype device ->
  -- | input tensor; its trailing dimensions must equal @normalizedShape@
  Tensor device dtype shape ->
  Tensor device dtype shape
layerNormForward LayerNorm {..} =
  -- The explicit type application pins which suffix of @shape@ is normalised.
  layerNorm @normalizedShape
    (toDependent layerNormWeight)
    (toDependent layerNormBias)
    layerNormEps

-- | 'forward' delegates to 'layerNormForward'.
--
-- Layer normalisation has no stochastic behaviour here, so 'forwardStoch'
-- just lifts the deterministic result into 'IO' with 'pure'.
instance
  ( IsSuffixOf normalizedShape shape,
    KnownShape normalizedShape
  ) =>
  HasForward (LayerNorm normalizedShape dtype device) (Tensor device dtype shape) (Tensor device dtype shape)
  where
  forward = layerNormForward
  forwardStoch = (pure .) . forward

-- | Sample a fresh 'LayerNorm' from a 'LayerNormSpec'.
--
-- Weight and bias are each drawn independently with 'randn' and wrapped as
-- trainable parameters with 'makeIndependent'. The epsilon is copied from
-- 'layerNormEpsSpec'.
--
-- NOTE(review): PyTorch's @nn.LayerNorm@ initialises the weight to ones and
-- the bias to zeros (an identity affine map); here both are random. Confirm
-- this initialisation is intended.
instance
  ( TensorOptions normalizedShape dtype device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (LayerNormSpec normalizedShape dtype device)
    (LayerNorm normalizedShape dtype device)
  where
  sample LayerNormSpec {..} =
    LayerNorm
      <$> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)
      <*> pure layerNormEpsSpec