{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}
module Torch.GraduallyTyped.NN.Initialization where
import Control.Monad.Catch (MonadThrow)
import Control.Monad.Indexed ((>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import GHC.Generics (Generic)
import Torch.GraduallyTyped.Internal.TensorOptions (tensorDims)
import Torch.GraduallyTyped.Random (Generator, SGetGeneratorDevice)
import Torch.GraduallyTyped.Scalar (Scalar)
import Torch.GraduallyTyped.Shape (Dim (..), dimSize)
import Torch.GraduallyTyped.Tensor.Creation (sRandn)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mulScalar, subScalar)
import Torch.GraduallyTyped.Tensor.Type (Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>))
-- | The non-linearity a weight tensor is being initialized for.
-- 'calculateGain' maps each constructor to a scalar gain.
data ForNonLinearity
  = ForIdentity
  | ForSigmoid
  | ForTanh
  | ForRelu
  | -- | Carries the 'Float' parameter that 'calculateGain' squares
    -- (presumably the negative slope of the leaky ReLU -- confirm with callers).
    ForLeakyRelu Float
  deriving stock (Eq, Ord, Show, Generic)
-- | Chooses which component of the @(fanIn, fanOut)@ pair produced by
-- 'calculateFan' is used for scaling; see 'getter'.
data FanMode = FanIn | FanOut
  deriving stock (Eq, Ord, Show, Generic)
-- | Common prefix for error messages raised by the initializers in this module.
errorPrefix :: String
errorPrefix = "Error during tensor initialization. "
-- | Gain (scaling factor) associated with a non-linearity.
--
-- * 'ForIdentity' and 'ForSigmoid' give @1@,
-- * 'ForTanh' gives @5 / 3@,
-- * 'ForRelu' gives @sqrt 2@,
-- * @'ForLeakyRelu' p@ gives @sqrt (2 / (1 + p^2))@.
calculateGain :: ForNonLinearity -> Float
calculateGain nonLinearity =
  case nonLinearity of
    ForIdentity -> 1
    ForSigmoid -> 1
    ForTanh -> 5 / 3
    ForRelu -> sqrt 2
    ForLeakyRelu param -> sqrt (2 / (1 + param ^^ (2 :: Integer)))
-- | Compute the @(fanIn, fanOut)@ pair for a tensor shape.
--
-- The first dimension is taken as the number of output feature maps and the
-- second as the number of input feature maps. Every dimension after the
-- first two contributes to the receptive field size, which multiplies both
-- fans. For a 2-dimensional shape the receptive field size is @1@, so the
-- result is simply @(second dimension, first dimension)@.
--
-- Calls 'error' (message prefixed with 'errorPrefix') when the shape has
-- fewer than two dimensions.
calculateFan ::
  [Dim String Integer] ->
  (Integer, Integer)
calculateFan shape =
  case dimSize <$> shape of
    numOutputFmaps : numInputFmaps : kernelSizes ->
      -- Only the dimensions *after* the two feature-map dimensions belong to
      -- the receptive field; 'product' of the empty list is 1, which covers
      -- the plain 2-dimensional case.
      let receptiveFieldSize = product kernelSizes
       in ( numInputFmaps * receptiveFieldSize,
            numOutputFmaps * receptiveFieldSize
          )
    _ ->
      error $ errorPrefix <> "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
-- | Xavier-style initialization with a symmetric bound.
--
-- From the shape in the given 'TensorSpec' the fans are computed with
-- 'calculateFan', then
--
-- > std   = gain * sqrt (2 / (fanIn + fanOut))
-- > bound = sqrt 3 * std
--
-- A tensor is sampled with 'sRandn' and mapped through @x * (2 * bound) - bound@.
-- Returns the initialized tensor together with the generator produced by
-- 'sRandn'.
--
-- 'calculateFan' calls 'error' when the shape has fewer than two dimensions.
--
-- NOTE(review): the affine map @x * 2 * bound - bound@ produces values in
-- @[-bound, bound)@ only if the sampled values lie in @[0, 1)@. Confirm that
-- 'sRandn' is the intended sampler for a "uniform" initializer.
sXavierUniform ::
  forall gradient layout device dataType shape gain generatorDevice m.
  ( Num gain,
    Floating gain,
    Scalar gain,
    MonadThrow m,
    SGetGeneratorDevice generatorDevice
  ) =>
  TensorSpec gradient layout device dataType shape ->
  gain ->
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sXavierUniform tensorSpec@TensorSpec {..} gain =
  let dims = tensorDims tsShape
      (fanIn, fanOut) = calculateFan dims
      std = gain * sqrt (2 / (fromIntegral fanIn + fromIntegral fanOut))
      bound = sqrt 3 * std
   in runIxStateT $
        -- Sampling threads the generator through the indexed state; the
        -- rescaling afterwards does not touch the generator, hence 'ilift'.
        IxStateT (sRandn tensorSpec)
          >>>= \initTensor -> ilift $ do
            x <- initTensor `mulScalar` (bound * 2)
            x `subScalar` bound
-- | Xavier-style initialization scaled by a standard deviation.
--
-- From the shape in the given 'TensorSpec' the fans are computed with
-- 'calculateFan', then
--
-- > std = gain * sqrt (2 / (fanIn + fanOut))
--
-- A tensor is sampled with 'sRandn' and multiplied by @std@. Returns the
-- initialized tensor together with the generator produced by 'sRandn'.
--
-- 'calculateFan' calls 'error' when the shape has fewer than two dimensions.
sXavierNormal ::
  forall gradient layout device dataType shape gain generatorDevice m.
  ( Num gain,
    Floating gain,
    Scalar gain,
    MonadThrow m,
    SGetGeneratorDevice generatorDevice
  ) =>
  TensorSpec gradient layout device dataType shape ->
  gain ->
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sXavierNormal tensorSpec@TensorSpec {..} gain =
  let dims = tensorDims tsShape
      (fanIn, fanOut) = calculateFan dims
      std = gain * sqrt (2 / (fromIntegral fanIn + fromIntegral fanOut))
   in runIxStateT $
        -- Sampling threads the generator through the indexed state; the
        -- rescaling afterwards does not touch the generator, hence 'ilift'.
        IxStateT (sRandn tensorSpec)
          >>>= \initTensor -> ilift $ initTensor `mulScalar` std
-- | Projection matching a 'FanMode': 'fst' for 'FanIn' and 'snd' for
-- 'FanOut'. Intended to be applied to the pair returned by 'calculateFan'.
getter :: forall a. FanMode -> ((a, a) -> a)
getter fanMode =
  case fanMode of
    FanIn -> fst
    FanOut -> snd
-- | Kaiming-style initialization with a symmetric bound.
--
-- The gain comes from 'calculateGain' for the given non-linearity, and the
-- fan value is the component of 'calculateFan' selected by the 'FanMode'
-- (via 'getter'). Then
--
-- > std   = gain / sqrt fanValue
-- > bound = sqrt 3 * std
--
-- A tensor is sampled with 'sRandn' and mapped through @x * (2 * bound) - bound@.
-- Returns the initialized tensor together with the generator produced by
-- 'sRandn'.
--
-- 'calculateFan' calls 'error' when the shape has fewer than two dimensions.
--
-- NOTE(review): the affine map @x * 2 * bound - bound@ produces values in
-- @[-bound, bound)@ only if the sampled values lie in @[0, 1)@. Confirm that
-- 'sRandn' is the intended sampler for a "uniform" initializer.
sKaimingUniform ::
  forall gradient layout device dataType shape generatorDevice m.
  ( MonadThrow m,
    SGetGeneratorDevice generatorDevice
  ) =>
  TensorSpec gradient layout device dataType shape ->
  FanMode ->
  ForNonLinearity ->
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sKaimingUniform tensorSpec@TensorSpec {..} fanMode nonLinearity =
  let dims = tensorDims tsShape
      gain = calculateGain nonLinearity
      fanValue = fromIntegral $ getter fanMode (calculateFan dims)
      std = gain / sqrt fanValue
      bound = sqrt 3 * std
   in runIxStateT $
        -- Sampling threads the generator through the indexed state; the
        -- rescaling afterwards does not touch the generator, hence 'ilift'.
        IxStateT (sRandn tensorSpec)
          >>>= \initTensor -> ilift $ do
            x <- initTensor `mulScalar` (bound * 2)
            x `subScalar` bound
-- | Kaiming-style initialization scaled by a standard deviation.
--
-- The gain comes from 'calculateGain' for the given non-linearity, and the
-- fan value is the component of 'calculateFan' selected by the 'FanMode'
-- (via 'getter'). Then
--
-- > std = gain / sqrt fanValue
--
-- A tensor is sampled with 'sRandn' and multiplied by @std@. Returns the
-- initialized tensor together with the generator produced by 'sRandn'.
--
-- 'calculateFan' calls 'error' when the shape has fewer than two dimensions.
sKaimingNormal ::
  forall gradient layout device dataType shape generatorDevice m.
  ( MonadThrow m,
    SGetGeneratorDevice generatorDevice
  ) =>
  TensorSpec gradient layout device dataType shape ->
  FanMode ->
  ForNonLinearity ->
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sKaimingNormal tensorSpec@TensorSpec {..} fanMode nonLinearity =
  let dims = tensorDims tsShape
      gain = calculateGain nonLinearity
      fanValue = fromIntegral $ getter fanMode (calculateFan dims)
      std = gain / sqrt fanValue
   in runIxStateT $
        -- Sampling threads the generator through the indexed state; the
        -- rescaling afterwards does not touch the generator, hence 'ilift'.
        IxStateT (sRandn tensorSpec)
          >>>= \initTensor -> ilift $ initTensor `mulScalar` std