{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}

module Torch.GraduallyTyped.NN.Initialization where

import Control.Monad.Catch (MonadThrow)
import Control.Monad.Indexed ((>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import GHC.Generics (Generic)
import Torch.GraduallyTyped.Internal.TensorOptions (tensorDims)
import Torch.GraduallyTyped.Random (Generator, SGetGeneratorDevice)
import Torch.GraduallyTyped.Scalar (Scalar)
import Torch.GraduallyTyped.Shape (Dim (..), dimSize)
import Torch.GraduallyTyped.Tensor.Creation (sRandn)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mulScalar, subScalar)
import Torch.GraduallyTyped.Tensor.Type (Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>))

-- | Non-linearity that follows the layer being initialized; it selects the
-- gain returned by 'calculateGain'.
--
-- Note: 'ForIdentity' stands for a linear layer without activation.
-- 'ForLeakyRelu' carries a 'Float' parameter that 'calculateGain' squares
-- (presumably the negative slope of the leaky ReLU -- confirm with callers).
data ForNonLinearity
  = ForIdentity
  | ForSigmoid
  | ForTanh
  | ForRelu
  | ForLeakyRelu Float
  deriving stock (Eq, Ord, Show, Generic)

-- | Selects which component of the @(fanIn, fanOut)@ pair is used for scaling
-- (see 'getter'): 'FanIn' picks the first component, 'FanOut' the second.
data FanMode = FanIn | FanOut
  deriving stock (Eq, Ord, Show, Generic)

-- | Common prefix for error messages raised by the initializers in this module.
errorPrefix :: String
errorPrefix = "Error during tensor initialization. "

-- | Gain scaling value for He initialization.
--
-- Returns the multiplicative gain associated with a non-linearity:
--
-- * 'ForIdentity' and 'ForSigmoid': @1@
-- * 'ForTanh': @5 / 3@
-- * 'ForRelu': @sqrt 2@
-- * @'ForLeakyRelu' param@: @sqrt (2 / (1 + param ^ 2))@
calculateGain :: ForNonLinearity -> Float
calculateGain ForIdentity = 1
calculateGain ForSigmoid = 1
calculateGain ForTanh = 5 / 3
calculateGain ForRelu = sqrt 2
calculateGain (ForLeakyRelu param) = sqrt (2 / (1 + param ^^ (2 :: Integer)))

-- | Fan-in / fan-out scaling calculation.
--
-- The first dimension is taken as the number of output feature maps and the
-- second as the number of input feature maps. Any further dimensions form the
-- receptive field, whose size (the product of those dimensions, @1@ when
-- there are none) multiplies both counts.
--
-- Returns @(fanIn, fanOut)@.
--
-- Calls 'error' when the shape has fewer than two dimensions.
calculateFan ::
  [Dim String Integer] ->
  (Integer, Integer)
calculateFan shape =
  case dimSize <$> shape of
    numOutputFmaps : numInputFmaps : receptiveFieldDims ->
      -- Only the dimensions after the first two belong to the receptive
      -- field. Including the input-channel dimension here would inflate both
      -- fans by a factor of 'numInputFmaps' for tensors of rank three or more.
      let receptiveFieldSize = product receptiveFieldDims
       in ( numInputFmaps * receptiveFieldSize,
            numOutputFmaps * receptiveFieldSize
          )
    _ ->
      error $
        errorPrefix
          <> "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"

-- | Xavier uniform initialization.
--
-- Given a tensor specification, a gain, and a random generator, produces a
-- tensor of the specified shape together with the updated generator.
--
-- With @(fanIn, fanOut)@ obtained from 'calculateFan' on the dimensions of the
-- specification's shape:
--
-- > std   = gain * sqrt (2 / (fanIn + fanOut))
-- > bound = sqrt 3 * std
--
-- The sampled tensor is multiplied by @2 * bound@ and then shifted by
-- @-bound@.
--
-- NOTE(review): the sample is drawn with 'sRandn'. The affine map above sends
-- @[0, 1)@ onto @[-bound, bound)@, so it looks like a uniform sampler was
-- intended here -- confirm and swap the sampler if one is available.
--
-- Fails (via 'error' in 'calculateFan') when the shape has fewer than two
-- dimensions.
sXavierUniform ::
  forall gradient layout device dataType shape gain generatorDevice m.
  ( Num gain,
    Floating gain,
    Scalar gain,
    MonadThrow m,
    SGetGeneratorDevice generatorDevice
  ) =>
  TensorSpec gradient layout device dataType shape ->
  gain ->
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sXavierUniform tensorSpec@TensorSpec {..} gain =
  let dims = tensorDims tsShape
      (fanIn, fanOut) = calculateFan dims
      std = gain * sqrt (2 / (fromIntegral fanIn + fromIntegral fanOut))
      bound = sqrt 3 * std
   in runIxStateT $
        IxStateT (sRandn tensorSpec)
          >>>= \initTensor -> ilift $ do
            -- Scale by the width of the target interval, then re-centre it.
            x <- initTensor `mulScalar` (bound * 2)
            x `subScalar` bound

-- | Xavier normal initialization.
--
-- Given a tensor specification, a gain, and a random generator, draws a
-- tensor of the specified shape with 'sRandn' and multiplies it by
--
-- > std = gain * sqrt (2 / (fanIn + fanOut))
--
-- where @(fanIn, fanOut)@ comes from 'calculateFan' applied to the dimensions
-- of the specification's shape. Returns the scaled tensor together with the
-- updated generator.
--
-- Fails (via 'error' in 'calculateFan') when the shape has fewer than two
-- dimensions.
sXavierNormal ::
  forall gradient layout device dataType shape gain generatorDevice m.
  ( Num gain,
    Floating gain,
    Scalar gain,
    MonadThrow m,
    SGetGeneratorDevice generatorDevice
  ) =>
  TensorSpec gradient layout device dataType shape ->
  gain ->
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sXavierNormal tensorSpec@TensorSpec {..} gain =
  let dims = tensorDims tsShape
      (fanIn, fanOut) = calculateFan dims
      std = gain * sqrt (2 / (fromIntegral fanIn + fromIntegral fanOut))
   in runIxStateT $
        IxStateT (sRandn tensorSpec)
          >>>= \initTensor -> ilift $ initTensor `mulScalar` std

-- | Get fan in or fan out value depending on selected fan mode, used by
-- the Kaiming initializers.
--
-- 'FanIn' selects the first component of the pair, 'FanOut' the second.
getter :: forall a. FanMode -> ((a, a) -> a)
getter FanIn = fst
getter FanOut = snd

-- | Kaiming uniform initialization.
--
-- Given a tensor specification, a fan mode, a non-linearity, and a random
-- generator, produces a tensor of the specified shape together with the
-- updated generator.
--
-- With @gain = 'calculateGain' nonLinearity@ and @fanValue@ the component of
-- 'calculateFan' selected by the fan mode (see 'getter'):
--
-- > std   = gain / sqrt fanValue
-- > bound = sqrt 3 * std
--
-- The sampled tensor is multiplied by @2 * bound@ and then shifted by
-- @-bound@.
--
-- NOTE(review): the sample is drawn with 'sRandn'. The affine map above sends
-- @[0, 1)@ onto @[-bound, bound)@, so it looks like a uniform sampler was
-- intended here -- confirm and swap the sampler if one is available.
--
-- Fails (via 'error' in 'calculateFan') when the shape has fewer than two
-- dimensions.
sKaimingUniform ::
  forall gradient layout device dataType shape generatorDevice m.
  ( MonadThrow m,
    SGetGeneratorDevice generatorDevice
  ) =>
  TensorSpec gradient layout device dataType shape ->
  FanMode ->
  ForNonLinearity ->
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sKaimingUniform tensorSpec@TensorSpec {..} fanMode nonLinearity =
  let dims = tensorDims tsShape
      gain = calculateGain nonLinearity
      fanValue = fromIntegral $ getter fanMode (calculateFan dims)
      std = gain / sqrt fanValue
      bound = sqrt 3 * std
   in runIxStateT $
        IxStateT (sRandn tensorSpec)
          >>>= \initTensor -> ilift $ do
            -- Scale by the width of the target interval, then re-centre it.
            x <- initTensor `mulScalar` (bound * 2)
            x `subScalar` bound

-- | Kaiming normal initialization.
--
-- Given a tensor specification, a fan mode, a non-linearity, and a random
-- generator, draws a tensor of the specified shape with 'sRandn' and
-- multiplies it by
--
-- > std = gain / sqrt fanValue
--
-- where @gain = 'calculateGain' nonLinearity@ and @fanValue@ is the component
-- of 'calculateFan' selected by the fan mode (see 'getter'). Returns the
-- scaled tensor together with the updated generator.
--
-- Fails (via 'error' in 'calculateFan') when the shape has fewer than two
-- dimensions.
sKaimingNormal ::
  forall gradient layout device dataType shape generatorDevice m.
  ( MonadThrow m,
    SGetGeneratorDevice generatorDevice
  ) =>
  TensorSpec gradient layout device dataType shape ->
  FanMode ->
  ForNonLinearity ->
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sKaimingNormal tensorSpec@TensorSpec {..} fanMode nonLinearity =
  let dims = tensorDims tsShape
      gain = calculateGain nonLinearity
      fanValue = fromIntegral $ getter fanMode (calculateFan dims)
      std = gain / sqrt fanValue
   in runIxStateT $
        IxStateT (sRandn tensorSpec)
          >>>= \initTensor -> ilift $ initTensor `mulScalar` std