{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Functional.Normalization where

import GHC.TypeLits (Nat, Symbol, TypeError, type (+), type (-))
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.Prelude (Length, Reverse)
import Torch.GraduallyTyped.Scalar ()
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SelectDims (..), Shape (..), Size (..), dimSize)
import Torch.GraduallyTyped.Tensor.Type (SGetShape (getDims), Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Torch.Internal.Cast (cast5, cast6)
import qualified Torch.Internal.Managed.Native as ATen
import Type.Errors.Pretty (type (%), type (<>))

-- | Unifies the normalized dimensions with the trailing dimensions of the
-- input, producing the output dimensions of a layer norm.
--
-- Both dimension lists are expected in /reverse/ order (innermost dimension
-- first) so that the normalized dims line up with the input's last dims.
-- Equation order matters: the empty-normalized case must come first so that
-- leading input dims pass through untouched.
type family LayerNormImplF (reverseNormalizedDims :: [Dim (Name Symbol) (Size Nat)]) (reverseInputDims :: [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where
  -- No normalized dims remain: the remaining (leading) input dims pass through.
  LayerNormImplF '[] reverseInputDims = reverseInputDims
  -- Unify each normalized dim with the corresponding input dim, recursing on the rest.
  LayerNormImplF (normalizedDim ': reverseNormalizedDims) (inputDim ': reverseInputDims) = normalizedDim <+> inputDim ': LayerNormImplF reverseNormalizedDims reverseInputDims
  -- Normalized dims left over but the input is exhausted: compile-time error.
  LayerNormImplF _ '[] = TypeError LayerNormShapeErrorMessage

-- | Error message raised by 'LayerNormImplF' when the normalized shape has
-- more dimensions than the input shape.
type LayerNormShapeErrorMessage =
  "Cannot apply the layer norm. "
    % "The normalized shape exceeds the input shape."

-- | Output shape of a layer norm with both weight and bias.
--
-- If any of the three shapes is unchecked, the result is unchecked.
-- Otherwise the weight and bias dims are unified and matched against the
-- trailing input dims via 'LayerNormImplF' (the 'Reverse' calls align the
-- innermost dimensions).
type family LayerNormWithBiasF (weightShape :: Shape [Dim (Name Symbol) (Size Nat)]) (biasShape :: Shape [Dim (Name Symbol) (Size Nat)]) (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  LayerNormWithBiasF 'UncheckedShape _ _ = 'UncheckedShape
  LayerNormWithBiasF _ 'UncheckedShape _ = 'UncheckedShape
  LayerNormWithBiasF _ _ 'UncheckedShape = 'UncheckedShape
  LayerNormWithBiasF ('Shape weightDims) ('Shape biasDims) ('Shape inputDims) = 'Shape (Reverse (LayerNormImplF (Reverse (weightDims <+> biasDims)) (Reverse inputDims)))

-- | Applies layer normalization over the trailing dimensions of the input
-- tensor, scaling by @weight@ and shifting by @bias@.
--
-- The dimensions normalized over are the last dimensions of the input, as
-- given by the runtime shape of @weight@ (hence the 'SGetShape' constraint).
--
-- NOTE(review): the output gradient was previously
-- @gradient' <|> gradient' <|> gradient''@, which dropped the weight's
-- @gradient@; it is now unified over all three tensors, consistent with the
-- layout\/device\/dataType lines and with 'layerNormWithoutBias'.
layerNormWithBias ::
  forall gradient gradient' gradient'' layout layout' layout'' device device' device'' dataType dataType' dataType'' shape shape' shape''.
  SGetShape shape =>
  -- | weight
  Tensor gradient layout device dataType shape ->
  -- | bias
  Tensor gradient' layout' device' dataType' shape' ->
  -- | eps
  Double ->
  -- | input
  Tensor gradient'' layout'' device'' dataType'' shape'' ->
  -- | output
  Tensor
    (gradient <|> gradient' <|> gradient'')
    (layout <+> layout' <+> layout'')
    (device <+> device' <+> device'')
    (dataType <+> dataType' <+> dataType'')
    (LayerNormWithBiasF shape shape' shape'')
layerNormWithBias weight bias eps input = unsafePerformIO $ do
  -- The normalized shape handed to ATen is the runtime shape of the weight.
  let weightDims = getDims weight
  cast5 ATen.layer_norm_tlttd input (dimSize <$> weightDims) weight bias eps

-- | Output shape of a bias-free (T5-style) layer norm.
--
-- If either shape is unchecked, the result is unchecked. Otherwise the
-- weight dims are matched against the trailing input dims via
-- 'LayerNormImplF' (the 'Reverse' calls align the innermost dimensions).
type family LayerNormWithoutBiasF (weightShape :: Shape [Dim (Name Symbol) (Size Nat)]) (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  LayerNormWithoutBiasF 'UncheckedShape _ = 'UncheckedShape
  LayerNormWithoutBiasF _ 'UncheckedShape = 'UncheckedShape
  LayerNormWithoutBiasF ('Shape weightDims) ('Shape inputDims) = 'Shape (Reverse (LayerNormImplF (Reverse weightDims) (Reverse inputDims)))

-- | Type-level selection of the input dimensions a bias-free layer norm
-- normalizes over: the last @length weightDims@ dimensions of the input,
-- as 'ByIndex' selectors. Unchecked if either shape is unchecked.
type family LayerNormWithoutBiasSelectDimsF (weightShape :: Shape [Dim (Name Symbol) (Size Nat)]) (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) :: SelectDims [By Symbol Nat] where
  LayerNormWithoutBiasSelectDimsF 'UncheckedShape _ = 'UncheckedSelectDims
  LayerNormWithoutBiasSelectDimsF _ 'UncheckedShape = 'UncheckedSelectDims
  LayerNormWithoutBiasSelectDimsF ('Shape weightDims) ('Shape inputDims) = 'SelectDims (LayerNormWithoutBiasBysF weightDims inputDims (Length inputDims) 1)

-- | Builds the list of 'ByIndex' selectors for 'LayerNormWithoutBiasSelectDimsF'.
--
-- For each weight dimension consumed, emits the index
-- @inputDimsLength - counter@, starting with @counter = 1@ — i.e. the last
-- @length weightDims@ indexes of the input, innermost first.
type family LayerNormWithoutBiasBysF (weightDims :: [Dim (Name Symbol) (Size Nat)]) (inputDims :: [Dim (Name Symbol) (Size Nat)]) (inputDimsLength :: Nat) (counter :: Nat) :: [By Symbol Nat] where
  -- All weight dims consumed: done.
  LayerNormWithoutBiasBysF '[] _ _ _ = '[]
  -- Consume one dim from each list and emit the next index from the back.
  LayerNormWithoutBiasBysF (_ ': weightDims) (_ ': inputDims) inputDimsLength counter = 'ByIndex (inputDimsLength - counter) ': LayerNormWithoutBiasBysF weightDims inputDims inputDimsLength (counter + 1)
  -- Weight dims left over but input exhausted: compile-time error.
  LayerNormWithoutBiasBysF _ '[] inputDimsLength counter =
    TypeError
      ( "Cannot apply the layer norm."
          % "The provided weight tensor has more dimensions than the input tensor,"
          % ""
          % "    '" <> counter <> "'"
          % ""
          % "and"
          % ""
          % "    '" <> inputDimsLength <> "',"
          % ""
          % "respectively."
      )

-- | T5-style layer norm (RMS norm): normalizes the input by the root mean
-- square over its trailing dimensions — no mean subtraction and no bias —
-- then scales by @weight@:
--
-- > output = input * rsqrt(mean(input ^ 2, dims) + eps) * weight
--
-- where @dims@ are the last @length weightDims@ dimensions of the input.
layerNormWithoutBias ::
  forall gradient layout device dataType shape gradient' layout' device' dataType' shape'.
  (SGetShape shape, SGetShape shape') =>
  -- | weight
  Tensor gradient layout device dataType shape ->
  -- | eps
  Double ->
  -- | input
  Tensor gradient' layout' device' dataType' shape' ->
  -- | output
  Tensor
    (gradient <|> gradient')
    (layout <+> layout')
    (device <+> device')
    (dataType <+> dataType')
    (LayerNormWithoutBiasF shape shape')
layerNormWithoutBias weight eps input = unsafePerformIO $ do
  let weightDims = getDims weight
      inputDims = getDims input
  -- Indexes of the input dimensions to average over: the last
  -- @length weightDims@ dimensions of the input.
  let indexes :: [Int] = fromIntegral . (length inputDims -) <$> [1, 2 .. length weightDims]
  cast6 (go (null indexes)) input weight indexes eps (2 :: Double) True
  where
    -- Computes input * rsqrt(mean(input ^ exponent', indexes) + eps') * weight
    -- on raw ATen tensors. When there are no indexes to reduce over (scalar
    -- weight), the mean is skipped and the squared input is used directly.
    go nullIndexes input' weight' indexes eps' exponent' keepDim = do
      squaredInput <- ATen.pow_ts input' exponent'
      variance <-
        if nullIndexes
          then pure squaredInput
          else ATen.mean_tlb squaredInput indexes keepDim
      ATen.add_ts variance eps'
        >>= ATen.rsqrt_t
        >>= ATen.mul_tt input'
        >>= ATen.mul_tt weight'