module Torch.Initializers where
import Torch.Functional hiding (sqrt)
import Torch.Tensor
import Torch.TensorFactories
-- | Activation that follows a layer; 'calculateGain' maps each one to a
-- scaling gain. 'Identity' stands for a purely linear layer.
data NonLinearity
  = Identity
  | Sigmoid
  | Tanh
  | Relu
  | LeakyRelu Float
  -- ^ Leaky ReLU; the payload enters 'calculateGain' as @p@ in
  -- @sqrt (2 / (1 + p^2))@ (presumably the negative slope).
-- | Which half of the @(fanIn, fanOut)@ pair from 'calculateFan' is used to
-- scale Kaiming initialisation (see 'getter').
data FanMode
  = FanIn
  | FanOut
-- | Newtype wrapper around a list of tensor dimensions.
-- NOTE(review): the initialisers in this module take plain @[Int]@ shapes,
-- not 'Shape' — confirm whether anything elsewhere uses this type.
newtype Shape
  = Shape [Int]
-- | Recommended scaling gain for a non-linearity:
--
-- * 'Identity', 'Sigmoid': @1@
-- * 'Tanh': @5 / 3@
-- * 'Relu': @sqrt 2@
-- * 'LeakyRelu' @p@: @sqrt (2 / (1 + p^2))@
--
-- These constants appear to mirror PyTorch's @torch.nn.init.calculate_gain@.
calculateGain :: NonLinearity -> Float
calculateGain nl = case nl of
  Identity -> 1.0
  Sigmoid -> 1.0
  Tanh -> 5.0 / 3
  Relu -> sqrt 2.0
  LeakyRelu slope -> sqrt (2.0 / (1.0 + slope ^^ (2 :: Integer)))
-- | Fan-in and fan-out of a weight tensor with the given shape, returned as
-- @(fanIn, fanOut)@.
--
-- The shape is read as @[outputFmaps, inputFmaps, k1, k2, ...]@. Every
-- dimension after the first two forms the receptive field, whose size
-- multiplies both fans. For a 2-D weight @[out, in]@ the receptive field is
-- empty (size 1), so the result is @(in, out)@.
--
-- Calls 'error' for shapes with fewer than two dimensions.
calculateFan :: [Int] -> (Int, Int)
calculateFan (numOutputFmaps : numInputFmaps : kernelDims) =
  (numInputFmaps * receptiveFieldSize, numOutputFmaps * receptiveFieldSize)
  where
    -- Only the axes after the two feature-map axes belong to the receptive
    -- field (PyTorch uses @shape[2:]@). Taking everything after the first
    -- axis would multiply in the input-feature-map count a second time.
    receptiveFieldSize :: Int
    receptiveFieldSize = product kernelDims
calculateFan _ =
  error "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
-- | Xavier (Glorot) uniform initialisation.
--
-- Samples a tensor of the given shape with 'randIO'', scales each element by
-- @2 * b@ and then applies @'subScalar' b@, where
-- @b = sqrt 3 * gain * sqrt (2 / (fanIn + fanOut))@ and the fans come from
-- 'calculateFan'. Assuming 'randIO'' draws from [0, 1) and 'subScalar'
-- computes @t - b@ (TODO confirm), the result lies in [-b, b).
xavierUniform :: Float -> [Int] -> IO Tensor
xavierUniform gain shape = rescale <$> randIO' shape
  where
    rescale unitSample = subScalar bound (mulScalar (2.0 * bound) unitSample)
    (fanIn, fanOut) = calculateFan shape
    fanSum :: Float
    fanSum = fromIntegral fanIn + fromIntegral fanOut
    bound :: Float
    bound = sqrt 3.0 * (gain * sqrt (2.0 / fanSum))
-- | Xavier (Glorot) normal initialisation.
--
-- Samples a tensor of the given shape with 'randnIO'' and multiplies every
-- element by @gain * sqrt (2 / (fanIn + fanOut))@, with the fans taken from
-- 'calculateFan'.
xavierNormal :: Float -> [Int] -> IO Tensor
xavierNormal gain shape = fmap (mulScalar stdDev) (randnIO' shape)
  where
    (fanIn, fanOut) = calculateFan shape
    stdDev :: Float
    stdDev = gain * sqrt (2.0 / (fromIntegral fanIn + fromIntegral fanOut))
-- | Picks the fan used by Kaiming initialisation out of a
-- @(fanIn, fanOut)@ pair such as the one returned by 'calculateFan'.
getter :: FanMode -> ((Int, Int) -> Int)
getter mode = case mode of
  FanIn -> \(fanIn, _) -> fanIn
  FanOut -> \(_, fanOut) -> fanOut
-- | Kaiming (He) uniform initialisation.
--
-- The fan selected by @mode@ (via 'getter' on 'calculateFan') and the gain
-- of @nonlinearity@ (via 'calculateGain') give
-- @b = sqrt 3 * gain / sqrt fan@. Samples from 'randIO'' are scaled by
-- @2 * b@ and then shifted with @'subScalar' b@. Assuming 'randIO'' draws
-- from [0, 1) and 'subScalar' computes @t - b@ (TODO confirm), values lie in
-- [-b, b).
kaimingUniform :: FanMode -> NonLinearity -> [Int] -> IO Tensor
kaimingUniform mode nonlinearity shape = do
  samples <- randIO' shape
  let scaled = mulScalar (bound * 2.0) samples
  return (subScalar bound scaled)
  where
    fans = calculateFan shape
    fanValue :: Float
    fanValue = fromIntegral (getter mode fans)
    bound :: Float
    bound = sqrt 3.0 * (calculateGain nonlinearity / sqrt fanValue)
-- | Kaiming (He) normal initialisation.
--
-- Samples a tensor of the given shape with 'randnIO'' and multiplies it by
-- @calculateGain nonlinearity / sqrt fan@, where @fan@ is the half of
-- 'calculateFan' chosen by @mode@ through 'getter'.
kaimingNormal :: FanMode -> NonLinearity -> [Int] -> IO Tensor
kaimingNormal mode nonlinearity shape = do
  samples <- randnIO' shape
  pure (mulScalar stdDev samples)
  where
    chosenFan :: Float
    chosenFan = fromIntegral (getter mode (calculateFan shape))
    stdDev :: Float
    stdDev = calculateGain nonlinearity / sqrt chosenFan
-- | Initialises the weight and bias of a fully-connected layer.
--
-- * The weight, of shape @weightShape@, comes from 'kaimingUniform''.
-- * The bias has shape @[head weightShape]@. It is sampled with 'randIO'',
--   scaled by @2 * b@ and shifted with @'subScalar' b@, where
--   @b = 1 / sqrt fanIn@ and @fanIn@ comes from 'calculateFan'. Assuming
--   'randIO'' draws from [0, 1) (TODO confirm), the bias lies in [-b, b).
--
-- When @fanIn@ is 0 the bound is taken as 0. Otherwise @1 / sqrt 0@ would be
-- Infinity, and every bias entry would become NaN (@Inf - Inf@ or
-- @0 * Inf@). This matches PyTorch's @nn.Linear.reset_parameters@.
kaimingFC :: [Int] -> IO (Tensor, Tensor)
kaimingFC weightShape = do
  weight <- kaimingUniform' weightShape
  biasInit <- randIO' biasShape
  let bias :: Tensor
      bias = subScalar bound $ mulScalar (bound * 2.0) biasInit
  pure (weight, bias)
  where
    (fanIn, _) = calculateFan weightShape
    bound :: Float
    bound
      | fanIn > 0 = 1.0 / sqrt (fromIntegral fanIn)
      | otherwise = 0.0
    biasShape :: [Int]
    biasShape = [head weightShape]
-- | 'kaimingUniform' with fan-in mode and 'LeakyRelu' slope 0 (gain
-- @sqrt 2@); these appear to mirror PyTorch's @kaiming_uniform_@ defaults.
kaimingUniform' :: [Int] -> IO Tensor
kaimingUniform' shape = kaimingUniform FanIn (LeakyRelu 0.0) shape
-- | 'kaimingNormal' with fan-in mode and 'LeakyRelu' slope 0 (gain
-- @sqrt 2@); these appear to mirror PyTorch's @kaiming_normal_@ defaults.
kaimingNormal' :: [Int] -> IO Tensor
kaimingNormal' shape = kaimingNormal FanIn (LeakyRelu 0.0) shape
-- | 'xavierUniform' with a gain of 1.
xavierUniform' :: [Int] -> IO Tensor
xavierUniform' shape = xavierUniform 1.0 shape
-- | 'xavierNormal' with a gain of 1.
xavierNormal' :: [Int] -> IO Tensor
xavierNormal' shape = xavierNormal 1.0 shape