{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.NN.Recurrent.Cell.LSTM where

import GHC.Generics
import Torch

-- | Sizes needed to sample a randomly initialised 'LSTMCell'.
data LSTMSpec = LSTMSpec
  { -- | Size of the input dimension; becomes the second dimension of the
    -- input-to-hidden weight matrix.
    inputSize :: Int,
    -- | Size of the hidden dimension; the sampled weights and biases have a
    -- leading dimension of @4 * hiddenSize@, and the initialisation range is
    -- derived from it.
    hiddenSize :: Int
  }
  deriving (Eq, Show)

-- | Trainable parameters of a single LSTM cell.
--
-- The shapes noted on each field are the ones produced by the 'Randomizable'
-- instance for 'LSTMSpec'; the type itself does not enforce them.
data LSTMCell = LSTMCell
  { -- | Input-to-hidden weights, sampled as @[4 * hiddenSize, inputSize]@.
    weightsIH :: Parameter,
    -- | Hidden-to-hidden weights, sampled as @[4 * hiddenSize, hiddenSize]@.
    weightsHH :: Parameter,
    -- | Input-to-hidden bias, sampled as @[4 * hiddenSize]@.
    biasIH :: Parameter,
    -- | Hidden-to-hidden bias, sampled as @[4 * hiddenSize]@.
    biasHH :: Parameter
  }
  deriving (Generic, Show)

-- | Run a single step of the LSTM cell.
--
-- Each parameter of the cell is converted to a plain tensor with
-- 'toDependent' and handed, together with the current state and the input,
-- to 'lstmCell'. The result is whatever 'lstmCell' returns for the new
-- @(hidden, cell)@ state.
lstmCellForward ::
  -- | cell parameters
  LSTMCell ->
  -- | (hidden, cell)
  (Tensor, Tensor) ->
  -- | input
  Tensor ->
  -- | output (hidden, cell)
  (Tensor, Tensor)
lstmCellForward LSTMCell {..} hidden input =
  lstmCell weightsIH' weightsHH' biasIH' biasHH' hidden input
  where
    -- Primed names are the tensor views of the record's parameters.
    weightsIH' = toDependent weightsIH
    weightsHH' = toDependent weightsHH
    biasIH' = toDependent biasIH
    biasHH' = toDependent biasHH

-- | No methods are defined here, so the class's default implementations are
-- used. NOTE(review): presumably those defaults are Generic-based (which
-- would be why 'LSTMCell' derives 'Generic') -- confirm against the
-- definition of 'Parameterized'.
instance Parameterized LSTMCell

-- | Sample a randomly initialised 'LSTMCell' from an 'LSTMSpec'.
--
-- Every weight and bias is drawn with 'randIO'' and then rescaled by
-- @initScale@, i.e. @t -> 2 * scale * t - scale@ with
-- @scale = sqrt (1 / hiddenSize)@.
-- NOTE(review): this assumes 'randIO'' draws from @[0, 1)@ and that
-- @subScalar s t@ computes @t - s@, which would put every value in
-- @[-scale, scale)@ -- confirm against the Torch API.
-- A @hiddenSize@ of 0 makes @scale@ infinite; this is not guarded against.
instance Randomizable LSTMSpec LSTMCell where
  sample LSTMSpec {..} = do
    -- The leading dimension of every parameter is 4 * hiddenSize; see
    -- https://pytorch.org/docs/master/generated/torch.nn.LSTMCell.html
    -- The four draws are kept in this order so the sequence of random
    -- draws is unchanged.
    weightsIH' <- makeIndependent =<< initScale <$> randIO' [4 * hiddenSize, inputSize]
    weightsHH' <- makeIndependent =<< initScale <$> randIO' [4 * hiddenSize, hiddenSize]
    biasIH' <- makeIndependent =<< initScale <$> randIO' [4 * hiddenSize]
    biasHH' <- makeIndependent =<< initScale <$> randIO' [4 * hiddenSize]
    pure $
      LSTMCell
        { weightsIH = weightsIH',
          weightsHH = weightsHH',
          biasIH = biasIH',
          biasHH = biasHH'
        }
    where
      -- Half-width of the initialisation range.
      scale = Prelude.sqrt $ 1.0 / fromIntegral hiddenSize :: Float
      -- Applied right to left: multiply by 2, multiply by scale, then
      -- apply subScalar with scale.
      initScale = subScalar scale . mulScalar scale . mulScalar (2.0 :: Float)