{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.NN.Recurrent.Cell.Elman where

import GHC.Generics
import Torch

-- | Size hyper-parameters used by 'sample' to build an 'ElmanCell'.
data ElmanSpec = ElmanSpec
  { inputSize :: Int
  -- ^ Width of the input; 'sample' uses it as the last dimension of the
  -- input-to-hidden weight matrix.
  , hiddenSize :: Int
  -- ^ Width of the hidden state; 'sample' uses it as the leading dimension
  -- of every parameter tensor.
  }
  deriving (Eq, Show)

-- | Trainable parameters of a single Elman recurrent cell.
--
-- Shapes noted below are those produced by 'sample' in this module; the
-- type itself does not enforce them.
data ElmanCell = ElmanCell
  { weightsIH :: Parameter
  -- ^ Input-to-hidden weights (created as @[hiddenSize, inputSize]@).
  , weightsHH :: Parameter
  -- ^ Hidden-to-hidden weights (created as @[hiddenSize, hiddenSize]@).
  , biasIH :: Parameter
  -- ^ Bias paired with the input-to-hidden weights (created as @[hiddenSize]@).
  , biasHH :: Parameter
  -- ^ Bias paired with the hidden-to-hidden weights (created as @[hiddenSize]@).
  }
  deriving (Generic, Show)

-- | One step of an Elman recurrent cell, delegated to 'rnnReluCell'.
--
-- The four 'Parameter's are unwrapped with 'toDependent' and passed to
-- 'rnnReluCell' in the order input-to-hidden weights, hidden-to-hidden
-- weights, input-to-hidden bias, hidden-to-hidden bias, then the previous
-- hidden state and the current input.
--
-- NOTE(review): shapes are not checked here; presumably @input@ is
-- @[batch, inputSize]@ and @hidden@ is @[batch, hiddenSize]@ — confirm
-- against 'rnnReluCell'.
elmanCellForward ::
  -- | cell parameters
  ElmanCell ->
  -- | input
  Tensor ->
  -- | hidden
  Tensor ->
  -- | output
  Tensor
elmanCellForward ElmanCell {..} input hidden =
  rnnReluCell weightsIH' weightsHH' biasIH' biasHH' hidden input
  where
    weightsIH', weightsHH', biasIH', biasHH' :: Tensor
    weightsIH' = toDependent weightsIH
    weightsHH' = toDependent weightsHH
    biasIH' = toDependent biasIH
    -- Each bias must come from its own field: reading 'biasIH' here would
    -- apply the input bias twice and leave 'biasHH' out of the forward pass
    -- (and therefore without any gradient).
    biasHH' = toDependent biasHH

-- | Empty instance: relies entirely on the class's default methods
-- (presumably the 'Generic'-based defaults, enabled by the derived 'Generic'
-- instance on 'ElmanCell' — confirm against the Torch 'Parameterized' docs).
instance Parameterized ElmanCell

-- | Build a fresh 'ElmanCell' whose tensor shapes follow the 'ElmanSpec'.
--
-- Each parameter is drawn with 'randnIO'' and wrapped with
-- 'makeIndependent'. The draws happen in field order (weightsIH, weightsHH,
-- biasIH, biasHH) because applicative sequencing in 'IO' runs left to right.
instance Randomizable ElmanSpec ElmanCell where
  sample ElmanSpec {..} =
    ElmanCell
      <$> freshParam [hiddenSize, inputSize]
      <*> freshParam [hiddenSize, hiddenSize]
      <*> freshParam [hiddenSize]
      <*> freshParam [hiddenSize]
    where
      -- Random tensor of the requested shape, promoted to an independent
      -- parameter.
      freshParam shape = makeIndependent =<< randnIO' shape