{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Functional.Normalization where
import GHC.TypeLits (Nat, Symbol, TypeError, type (+), type (-))
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.Prelude (Length, Reverse)
import Torch.GraduallyTyped.Scalar ()
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SelectDims (..), Shape (..), Size (..), dimSize)
import Torch.GraduallyTyped.Tensor.Type (SGetShape (getDims), Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Torch.Internal.Cast (cast5, cast6)
import qualified Torch.Internal.Managed.Native as ATen
import Type.Errors.Pretty (type (%), type (<>))
-- | Unifies the normalized dimensions of a layer norm with the trailing
-- dimensions of the input, dimension by dimension.
--
-- Both argument lists are expected in /reverse/ order so that the last
-- normalized dimension lines up with the last input dimension; the caller
-- reverses the result back (see 'LayerNormWithBiasF', 'LayerNormWithoutBiasF').
type family LayerNormImplF (reverseNormalizedDims :: [Dim (Name Symbol) (Size Nat)]) (reverseInputDims :: [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where
  -- No normalized dimensions left: remaining input dimensions pass through unchanged.
  LayerNormImplF '[] reverseInputDims = reverseInputDims
  -- Unify ('<+>') the current normalized dimension with the corresponding input dimension.
  LayerNormImplF (normalizedDim ': reverseNormalizedDims) (inputDim ': reverseInputDims) = normalizedDim <+> inputDim ': LayerNormImplF reverseNormalizedDims reverseInputDims
  -- Normalized dimensions remain but the input is exhausted: compile-time error.
  LayerNormImplF _ '[] = TypeError LayerNormShapeErrorMessage
-- | Error message emitted by 'LayerNormImplF' when the normalized shape has
-- more dimensions than the input shape.
type LayerNormShapeErrorMessage =
  "Cannot apply the layer norm. "
    % "The normalized shape exceeds the input shape."
-- | Output shape of a layer norm with both weight and bias.
--
-- If any of the three shapes is unchecked, the result is unchecked.
-- Otherwise the weight and bias dimension lists are unified ('<+>') and
-- matched against the trailing dimensions of the input via 'LayerNormImplF'
-- (both lists are reversed so the trailing dimensions line up, and the
-- result is reversed back).
type family LayerNormWithBiasF (weightShape :: Shape [Dim (Name Symbol) (Size Nat)]) (biasShape :: Shape [Dim (Name Symbol) (Size Nat)]) (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  LayerNormWithBiasF 'UncheckedShape _ _ = 'UncheckedShape
  LayerNormWithBiasF _ 'UncheckedShape _ = 'UncheckedShape
  LayerNormWithBiasF _ _ 'UncheckedShape = 'UncheckedShape
  LayerNormWithBiasF ('Shape weightDims) ('Shape biasDims) ('Shape inputDims) = 'Shape (Reverse (LayerNormImplF (Reverse (weightDims <+> biasDims)) (Reverse inputDims)))
-- | Applies layer normalization with a learned weight (scale) and bias
-- (shift), delegating to ATen's @layer_norm@. The normalized shape passed
-- to ATen is the list of dimension sizes of the weight tensor, i.e. the
-- normalization is taken over the trailing input dimensions covered by the
-- weight's shape.
--
-- NOTE(review): the result gradient was originally declared as
-- @gradient' <|> gradient' <|> gradient''@, which drops the weight's
-- @gradient@ and repeats the bias' @gradient'@ — inconsistent with the
-- layout\/device\/dataType lines below and with 'layerNormWithoutBias'.
-- Corrected to @gradient <|> gradient' <|> gradient''@.
layerNormWithBias ::
  forall gradient gradient' gradient'' layout layout' layout'' device device' device'' dataType dataType' dataType'' shape shape' shape''.
  SGetShape shape =>
  -- | weight (scale)
  Tensor gradient layout device dataType shape ->
  -- | bias (shift)
  Tensor gradient' layout' device' dataType' shape' ->
  -- | eps, numerical-stability constant forwarded to ATen
  Double ->
  -- | input
  Tensor gradient'' layout'' device'' dataType'' shape'' ->
  -- | output
  Tensor
    (gradient <|> gradient' <|> gradient'')
    (layout <+> layout' <+> layout'')
    (device <+> device' <+> device'')
    (dataType <+> dataType' <+> dataType'')
    (LayerNormWithBiasF shape shape' shape'')
layerNormWithBias weight bias eps input = unsafePerformIO $ do
  -- Runtime dimensions of the weight; their sizes form ATen's
  -- @normalized_shape@ argument.
  let weightDims = getDims weight
  -- unsafePerformIO: presumably safe because the ATen call only reads its
  -- arguments and allocates a fresh result tensor — TODO confirm this
  -- matches the conventions used elsewhere in the library.
  cast5 ATen.layer_norm_tlttd input (dimSize <$> weightDims) weight bias eps
-- | Output shape of a layer norm without bias: the weight shape is matched
-- against the trailing dimensions of the input shape via 'LayerNormImplF'
-- (reversed so the trailing dimensions line up, then reversed back).
-- Unchecked shapes propagate.
type family LayerNormWithoutBiasF (weightShape :: Shape [Dim (Name Symbol) (Size Nat)]) (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  LayerNormWithoutBiasF 'UncheckedShape _ = 'UncheckedShape
  LayerNormWithoutBiasF _ 'UncheckedShape = 'UncheckedShape
  LayerNormWithoutBiasF ('Shape weightDims) ('Shape inputDims) = 'Shape (Reverse (LayerNormImplF (Reverse weightDims) (Reverse inputDims)))
-- | Type-level dimension selectors for the reduction performed by the
-- bias-free layer norm.
--
-- Unchecked shapes yield unchecked selectors; otherwise one 'ByIndex'
-- selector per weight dimension is produced, addressing the trailing
-- dimensions of the input (see 'LayerNormWithoutBiasBysF'; the counter
-- starts at 1).
type family LayerNormWithoutBiasSelectDimsF (weightShape :: Shape [Dim (Name Symbol) (Size Nat)]) (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) :: SelectDims [By Symbol Nat] where
  LayerNormWithoutBiasSelectDimsF 'UncheckedShape _ = 'UncheckedSelectDims
  LayerNormWithoutBiasSelectDimsF _ 'UncheckedShape = 'UncheckedSelectDims
  LayerNormWithoutBiasSelectDimsF ('Shape weightDims) ('Shape inputDims) = 'SelectDims (LayerNormWithoutBiasBysF weightDims inputDims (Length inputDims) 1)
-- | Builds the 'ByIndex' selectors for the trailing input dimensions
-- covered by the weight: for a weight with @k@ dimensions and an input with
-- @n@ dimensions, the indexes @n - 1, n - 2, ..., n - k@ are produced
-- (@counter@ starts at 1 and increments per weight dimension).
--
-- Raises a custom type error when the weight has more dimensions than the
-- input.
type family LayerNormWithoutBiasBysF (weightDims :: [Dim (Name Symbol) (Size Nat)]) (inputDims :: [Dim (Name Symbol) (Size Nat)]) (inputDimsLength :: Nat) (counter :: Nat) :: [By Symbol Nat] where
  -- All weight dimensions consumed: done.
  LayerNormWithoutBiasBysF '[] _ _ _ = '[]
  -- Select the input dimension counted from the end, then recurse.
  LayerNormWithoutBiasBysF (_ ': weightDims) (_ ': inputDims) inputDimsLength counter = 'ByIndex (inputDimsLength - counter) ': LayerNormWithoutBiasBysF weightDims inputDims inputDimsLength (counter + 1)
  -- Weight dimensions remain but the input is exhausted: compile-time error.
  LayerNormWithoutBiasBysF _ '[] inputDimsLength counter =
    TypeError
      ( "Cannot apply the layer norm."
          % "The provided weight tensor has more dimensions than the input tensor,"
          % ""
          % " '" <> counter <> "'"
          % ""
          % "and"
          % ""
          % " '" <> inputDimsLength <> "',"
          % ""
          % "respectively."
      )
-- | Applies layer normalization without a bias and without subtracting the
-- mean — i.e. root-mean-square style normalization, as in T5 —
-- presumably; the arithmetic visible here is
-- @weight * input * rsqrt(mean(input^2, dims) + eps)@.
--
-- The mean is taken over the trailing input dimensions covered by the
-- weight's shape; if the weight has no dimensions, no reduction is
-- performed and the squared input itself feeds the @rsqrt@.
layerNormWithoutBias ::
  forall gradient layout device dataType shape gradient' layout' device' dataType' shape'.
  (SGetShape shape, SGetShape shape') =>
  -- | weight (scale)
  Tensor gradient layout device dataType shape ->
  -- | eps, numerical-stability constant added before the reciprocal square root
  Double ->
  -- | input
  Tensor gradient' layout' device' dataType' shape' ->
  -- | output
  Tensor
    (gradient <|> gradient')
    (layout <+> layout')
    (device <+> device')
    (dataType <+> dataType')
    (LayerNormWithoutBiasF shape shape')
layerNormWithoutBias weight eps input = unsafePerformIO $ do
  let weightDims = getDims weight
      inputDims = getDims input
  -- Indexes of the input dimensions to reduce over: the last
  -- @length weightDims@ dimensions of the input, i.e. n-1, n-2, ..., n-k.
  let indexes :: [Int] = fromIntegral . (length inputDims -) <$> [1, 2 .. length weightDims]
  -- unsafePerformIO: presumably safe because the ATen calls only read their
  -- arguments and allocate fresh result tensors — TODO confirm against the
  -- library's conventions.
  cast6 (go (null indexes)) input weight indexes eps (2 :: Double) True
  where
    -- Raw ATen pipeline: square, (optionally) mean-reduce, add eps, rsqrt,
    -- scale the input, scale by the weight.
    go nullIndexes input' weight' indexes' eps' exponent' keepDim = do
      squaredInput <- ATen.pow_ts input' exponent'
      -- Mean of squares over the selected dimensions; with nothing to
      -- reduce, the squared input is used as-is.
      meanOfSquares <-
        if nullIndexes
          then pure squaredInput
          else ATen.mean_tlb squaredInput indexes' keepDim
      ATen.add_ts meanOfSquares eps'
        >>= ATen.rsqrt_t
        >>= ATen.mul_tt input'
        >>= ATen.mul_tt weight'