{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Torch.Typed.NN.Recurrent.Cell.GRU where
import Data.List
( foldl',
scanl',
)
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.NN as A
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (linear)
import Torch.Typed.NN.Dropout
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Specification used to create a 'GRUCell'.
--
-- The type carries no runtime data: the input dimension, hidden dimension,
-- dtype and device are all tracked purely at the type level, so the single
-- nullary constructor is enough to select the matching 'GRUCell' type.
data
  GRUCellSpec
    (inputDim :: Nat)
    (hiddenDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = -- | The only inhabitant; all configuration lives in the type parameters.
    GRUCellSpec
  deriving (Show, Eq, Ord, Generic, Enum, Bounded)
-- | The trainable parameters of a GRU cell.
--
-- Every parameter has a leading dimension of @3 * hiddenDim@
-- (presumably the three GRU gate blocks stacked together -- confirm against
-- the 'gruCell' primitive that consumes them).
data
  GRUCell
    (inputDim :: Nat)
    (hiddenDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat)) = GRUCell
  { -- | Parameter of shape @[3 * hiddenDim, inputDim]@
    -- (by its name, the input-to-hidden weights).
    gruCell_w_ih :: Parameter device dtype '[3 * hiddenDim, inputDim],
    -- | Parameter of shape @[3 * hiddenDim, hiddenDim]@
    -- (by its name, the hidden-to-hidden weights).
    gruCell_w_hh :: Parameter device dtype '[3 * hiddenDim, hiddenDim],
    -- | Parameter of shape @[3 * hiddenDim]@
    -- (by its name, the input-to-hidden bias).
    gruCell_b_ih :: Parameter device dtype '[3 * hiddenDim],
    -- | Parameter of shape @[3 * hiddenDim]@
    -- (by its name, the hidden-to-hidden bias).
    gruCell_b_hh :: Parameter device dtype '[3 * hiddenDim]
  }
  deriving (Show, Generic, Parameterized)
-- | Creates a 'GRUCell' from a 'GRUCellSpec'.
--
-- Each of the four parameters is produced by an independent call to 'randn'
-- (its shape, dtype and device are fixed by the field's type) and then
-- wrapped with 'makeIndependent' to obtain a 'Parameter'.
instance
  ( KnownDevice device,
    KnownDType dtype,
    KnownNat inputDim,
    KnownNat hiddenDim,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (GRUCellSpec inputDim hiddenDim dtype device)
    (GRUCell inputDim hiddenDim dtype device)
  where
  sample GRUCellSpec =
    -- Field order: w_ih, w_hh, b_ih, b_hh (the order of the record fields).
    GRUCell
      <$> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)
-- | A single step of a 'GRUCell'.
--
-- Converts the cell's four parameters to tensors with 'toDependent' and
-- partially applies the 'gruCell' primitive to them; the two remaining
-- arguments (current hidden state, then input) are taken point-free.
gruCellForward ::
  forall inputDim hiddenDim batchSize dtype device.
  ( KnownDType dtype,
    KnownNat inputDim,
    KnownNat hiddenDim,
    KnownNat batchSize
  ) =>
  -- | The cell whose parameters are used
  GRUCell inputDim hiddenDim dtype device ->
  -- | The current hidden state
  Tensor device dtype '[batchSize, hiddenDim] ->
  -- | The input for this step
  Tensor device dtype '[batchSize, inputDim] ->
  -- | The next hidden state
  Tensor device dtype '[batchSize, hiddenDim]
gruCellForward GRUCell {..} =
  gruCell
    (toDependent gruCell_w_ih)
    (toDependent gruCell_w_hh)
    (toDependent gruCell_b_ih)
    (toDependent gruCell_b_hh)
-- | Runs a 'GRUCell' over a list of inputs and returns only the final
-- hidden state.
--
-- This is a strict left fold ('foldl'') of 'gruCellForward' over the inputs,
-- starting from the given hidden state. For an empty input list the initial
-- hidden state is returned unchanged.
gruFold ::
  forall inputDim hiddenDim batchSize dtype device.
  ( KnownDType dtype,
    KnownNat inputDim,
    KnownNat hiddenDim,
    KnownNat batchSize
  ) =>
  -- | The cell applied at every step
  GRUCell inputDim hiddenDim dtype device ->
  -- | The initial hidden state
  Tensor device dtype '[batchSize, hiddenDim] ->
  -- | The inputs, consumed from left to right
  [Tensor device dtype '[batchSize, inputDim]] ->
  -- | The hidden state after the last input
  Tensor device dtype '[batchSize, hiddenDim]
gruFold cell = foldl' (gruCellForward cell)
-- | Runs a 'GRUCell' over a list of inputs and returns every intermediate
-- hidden state.
--
-- This is a strict left scan ('scanl'') of 'gruCellForward' over the inputs.
-- As with any 'scanl'', the result starts with the initial hidden state, so
-- it has one more element than the input list.
gruCellScan ::
  forall inputDim hiddenDim batchSize dtype device.
  ( KnownDType dtype,
    KnownNat inputDim,
    KnownNat hiddenDim,
    KnownNat batchSize
  ) =>
  -- | The cell applied at every step
  GRUCell inputDim hiddenDim dtype device ->
  -- | The initial hidden state
  Tensor device dtype '[batchSize, hiddenDim] ->
  -- | The inputs, consumed from left to right
  [Tensor device dtype '[batchSize, inputDim]] ->
  -- | The initial hidden state followed by the state after each input
  [Tensor device dtype '[batchSize, hiddenDim]]
gruCellScan cell = scanl' (gruCellForward cell)