module Torch.Initializers where

import Torch.Functional hiding (sqrt)
import Torch.Tensor
import Torch.TensorFactories

-- Note: Identity = linear w/o activation
data NonLinearity = Identity | Sigmoid | Tanh | Relu | LeakyRelu Float

data FanMode = FanIn | FanOut

newtype Shape = Shape [Int]

-- | Gain scaling value for He initialization
calculateGain :: NonLinearity -> Float
calculateGain Identity = 1.0
calculateGain Sigmoid = 1.0
calculateGain Tanh = 5.0 / 3
calculateGain Relu = sqrt 2.0
calculateGain (LeakyRelu param) = sqrt (2.0 / (1.0 + param ^^ 2))
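-- Worked values implied by the equations above (added for reference): Identity
-- and Sigmoid give 1.0, Tanh gives 5/3 ~ 1.667, Relu gives sqrt 2 ~ 1.414, and
-- LeakyRelu 0.0 also reduces to sqrt 2, since sqrt (2 / (1 + 0 ^^ 2)) = sqrt 2.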

-- | Fan-in / Fan-out scaling calculation
calculateFan :: [Int] -> (Int, Int)
calculateFan shape
  | dimT < 2 = error "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
  | dimT == 2 = (shape !! 1, head shape)
  | otherwise = (numInputFmaps * receptiveFieldSize, numOutputFmaps * receptiveFieldSize)
  where
    dimT = length shape
    numInputFmaps = shape !! 1 -- size t 1
    numOutputFmaps = head shape -- size t 0
    receptiveFieldSize = product $ tail shape
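-- Worked example (illustrative): for a linear layer's weight stored as
-- [outFeatures, inFeatures], calculateFan [128, 64] == (64, 128), i.e. fan-in is
-- the second dimension and fan-out is the first.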

-- | Xavier Initialization - Uniform
xavierUniform :: Float -> [Int] -> IO Tensor
xavierUniform gain shape = do
  init <- randIO' shape
  pure $ subScalar bound $ mulScalar (bound * 2.0) init
  where
    (fanIn, fanOut) = calculateFan shape
    std = gain * sqrt (2.0 / (fromIntegral fanIn + fromIntegral fanOut))
    bound = sqrt 3.0 * std
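
-- Usage sketch (illustrative; the binding name and shape below are hypothetical,
-- not part of the original API): a [256, 784] weight has fanIn = 784 and
-- fanOut = 256, so entries are drawn from [-b, b) with
-- b = gain * sqrt (6 / (784 + 256)).
exampleXavierUniform :: IO Tensor
exampleXavierUniform = xavierUniform 1.0 [256, 784]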

-- | Xavier Initialization - Normal
xavierNormal :: Float -> [Int] -> IO Tensor
xavierNormal gain shape = do
  init <- randnIO' shape
  pure $ mulScalar std init
  where
    (fanIn, fanOut) = calculateFan shape
    std = gain * sqrt (2.0 / (fromIntegral fanIn + fromIntegral fanOut))
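-- For reference: xavierNormal scales a standard-normal sample from randnIO' by
-- std = gain * sqrt (2 / (fanIn + fanOut)), i.e. it draws from N(0, std^2);
-- only the sampling distribution differs from xavierUniform above.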

-- | Get fan in or fan out value depending on selected fan mode, used by Kaiming
getter :: FanMode -> ((Int, Int) -> Int)
getter FanIn = fst
getter FanOut = snd

-- | Kaiming Initialization - Uniform
kaimingUniform :: FanMode -> NonLinearity -> [Int] -> IO Tensor
kaimingUniform mode nonlinearity shape = do
  init <- randIO' shape
  pure $ subScalar bound $ mulScalar (bound * 2.0) init
  where
    fanValue = fromIntegral $ getter mode (calculateFan shape)
    std = calculateGain nonlinearity / sqrt fanValue
    bound = sqrt 3.0 * std
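-- Worked example (illustrative): with FanIn mode, Relu and a [256, 784] weight,
-- fanValue = 784, std = sqrt 2 / sqrt 784 and bound = sqrt 3 * std ~ 0.0875, so
-- entries land in [-bound, bound).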

-- | Kaiming Initialization - Normal
kaimingNormal :: FanMode -> NonLinearity -> [Int] -> IO Tensor
kaimingNormal mode nonlinearity shape = mulScalar std <$> randnIO' shape
  where
    fanValue = fromIntegral $ getter mode (calculateFan shape)
    std = calculateGain nonlinearity / sqrt fanValue
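-- For reference: kaimingNormal scales a standard-normal sample from randnIO' by
-- std = calculateGain nonlinearity / sqrt fanValue, i.e. it draws from
-- N(0, std^2) with the same std as the uniform variant above.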

-- | Kaiming-uniform initialization of a fully connected layer's weight and bias
-- based on https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L79
kaimingFC :: [Int] -> IO (Tensor, Tensor)
kaimingFC weightShape = do
  weight <- kaimingUniform' weightShape
  biasInit <- randIO' biasShape
  let bias = subScalar bound $ mulScalar (bound * 2.0) biasInit
  pure (weight, bias)
  where
    (fanIn, _) = calculateFan weightShape
    bound = 1.0 / (sqrt . fromIntegral $ fanIn) :: Float
    biasShape = [head weightShape]
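
-- Usage sketch (illustrative; the binding name and sizes are hypothetical): build
-- weight and bias for a 784 -> 256 fully connected layer. The weight shape is
-- [outFeatures, inFeatures], the bias has shape [outFeatures], and bias entries
-- are drawn from [-1 / sqrt fanIn, 1 / sqrt fanIn).
exampleLinearParams :: IO (Tensor, Tensor)
exampleLinearParams = kaimingFC [256, 784]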

{- PyTorch defaults -}

kaimingUniform' :: [Int] -> IO Tensor
kaimingUniform' = kaimingUniform FanIn (LeakyRelu 0.0)

kaimingNormal' :: [Int] -> IO Tensor
kaimingNormal' = kaimingNormal FanIn (LeakyRelu 0.0)

xavierUniform' :: [Int] -> IO Tensor
xavierUniform' = xavierUniform 1.0

xavierNormal' :: [Int] -> IO Tensor
xavierNormal' = xavierNormal 1.0
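
-- Note (added for clarity): LeakyRelu 0.0 yields a gain of sqrt 2, since
-- sqrt (2 / (1 + 0 ^^ 2)) = sqrt 2, so the Kaiming defaults above amount to
-- Relu-gain He initialization with fan-in scaling.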