{-# LANGUAGE DeriveGeneric #-}

module Torch.Typed.NN.Recurrent.Auxiliary where

import GHC.Generics
import Torch.Functional (mulScalar, subScalar)
import Torch.Tensor

-- | How the initial hidden state of a recurrent network is obtained:
-- either fixed to a constant or learned as a trainable parameter.
data RNNInitialization
  = ConstantInitialization
  | LearnedInitialization
  deriving (Show, Generic)

-- TODO: This is taken from the initializers example code and should be replaced with canonical,
-- tested versions. However, even a potentially incorrect implementation will likely perform
-- better than an ad-hoc random-normal distribution.

-- | Fan-in / Fan-out scaling calculation.
--
-- Follows the canonical convention (cf. PyTorch's
-- @torch.nn.init._calculate_fan_in_and_fan_out@):
--
-- * @fanIn  = shape !! 1 * receptiveFieldSize@
-- * @fanOut = shape !! 0 * receptiveFieldSize@
--
-- where the receptive field size is the product of all dimensions after
-- the first two (and is 1 for a plain 2-D weight matrix).
--
-- Calls 'error' for shapes with fewer than 2 dimensions, where fan-in and
-- fan-out are undefined. The @!!@ accesses below are safe because the
-- first guard has already established @dimT >= 2@.
calculateFan :: [Int] -> (Int, Int)
calculateFan shape
  | dimT < 2 =
    error
      "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
  | dimT == 2 =
    (numInputFmaps, numOutputFmaps)
  | otherwise =
    (numInputFmaps * receptiveFieldSize, numOutputFmaps * receptiveFieldSize)
  where
    dimT = length shape
    numInputFmaps = shape !! 1
    numOutputFmaps = shape !! 0
    -- Only the kernel dimensions (index 2 onwards) form the receptive
    -- field. The previous implementation used @tail shape@, which wrongly
    -- multiplied the input-feature-map dimension in a second time.
    receptiveFieldSize = product $ drop 2 shape

-- | Xavier (Glorot) Initialization - Uniform.
--
-- Rescales the given tensor into the symmetric interval
-- @[-bound, bound)@, where
-- @bound = sqrt 3 * gain * sqrt (2 / (fanIn + fanOut))@.
--
-- NOTE(review): the result is Xavier-uniform only if the input tensor is
-- uniformly distributed on @[0, 1)@ — TODO confirm at the call sites.
xavierUniformFIXME :: Tensor -> Float -> [Int] -> IO Tensor
xavierUniformFIXME initTensor gain shape = pure rescaled
  where
    (fanIn, fanOut) = calculateFan shape
    stdDev = gain * sqrt (2.0 / (fromIntegral fanIn + fromIntegral fanOut))
    bound = sqrt 3.0 * stdDev
    -- Map [0, 1) onto [-bound, bound): stretch by the interval width,
    -- then shift down by the half-width.
    rescaled = subScalar bound (mulScalar (bound * 2.0) initTensor)