{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Torch.NN.Recurrent.Cell.Elman where
import GHC.Generics
import Torch
-- | Hyperparameters used to sample a fresh 'ElmanCell'.
data ElmanSpec = ElmanSpec
  { -- | Width of each input vector; used as the second dimension of the
    -- input-to-hidden weight when sampling.
    inputSize :: Int,
    -- | Width of the hidden state; used for both hidden-to-hidden weight
    -- dimensions and for both bias vectors when sampling.
    hiddenSize :: Int
  }
  deriving (Show, Eq)
-- | Learnable parameters of an Elman (vanilla) recurrent cell.
--
-- The shapes noted on each field are those produced by this module's
-- 'Randomizable' instance. 'Generic' is derived so that the empty
-- 'Parameterized' instance can rely on its default methods.
data ElmanCell = ElmanCell
  { -- | Input-to-hidden weight, sampled with shape @[hiddenSize, inputSize]@.
    weightsIH :: Parameter,
    -- | Hidden-to-hidden weight, sampled with shape @[hiddenSize, hiddenSize]@.
    weightsHH :: Parameter,
    -- | Input-to-hidden bias, sampled with shape @[hiddenSize]@.
    biasIH :: Parameter,
    -- | Hidden-to-hidden bias, sampled with shape @[hiddenSize]@.
    biasHH :: Parameter
  }
  deriving (Show, Generic)
-- | Advance an Elman cell by one step.
--
-- Unwraps the cell's four 'Parameter's into 'Tensor's and delegates to
-- 'rnnReluCell'. Arguments are passed in the order weightsIH, weightsHH,
-- biasIH, biasHH, then the previous hidden state and the current input.
-- The activation is whatever 'rnnReluCell' applies (presumably ReLU, per its
-- name); shapes are assumed to match those produced by 'sample'. TODO confirm
-- batch layout.
elmanCellForward ::
  -- | cell parameters
  ElmanCell ->
  -- | input for the current step
  Tensor ->
  -- | hidden state from the previous step
  Tensor ->
  -- | new hidden state
  Tensor
elmanCellForward ElmanCell {..} input hidden =
  rnnReluCell weightsIH' weightsHH' biasIH' biasHH' hidden input
  where
    weightsIH' = toDependent weightsIH
    weightsHH' = toDependent weightsHH
    biasIH' = toDependent biasIH
    -- Must read 'biasHH' (not 'biasIH'): otherwise the hidden-to-hidden bias
    -- is silently ignored, never receives gradients, and the input-to-hidden
    -- bias is applied twice.
    biasHH' = toDependent biasHH
-- No methods are defined here: the class's default implementations are used,
-- presumably the 'Generic'-driven ones that walk the four 'Parameter' fields.
instance Parameterized ElmanCell where
instance Randomizable ElmanSpec ElmanCell where
  -- Sample every weight and bias from 'randnIO'' (no rescaling) and wrap it
  -- with 'makeIndependent'. Draws happen in field order:
  -- weightsIH, weightsHH, biasIH, biasHH.
  sample ElmanSpec {..} = do
    let freshParam shape = makeIndependent =<< randnIO' shape
    wIH <- freshParam [hiddenSize, inputSize]
    wHH <- freshParam [hiddenSize, hiddenSize]
    bIH <- freshParam [hiddenSize]
    bHH <- freshParam [hiddenSize]
    pure (ElmanCell wIH wHH bIH bHH)