{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Torch.Typed.NN.Recurrent.LSTM where
import Data.Kind
import Data.Proxy (Proxy (..))
import Foreign.ForeignPtr
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import System.Environment
import System.IO.Unsafe
import qualified Torch.Autograd as A
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import qualified Torch.NN as A
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (sqrt)
import Torch.Typed.NN.Dropout
import Torch.Typed.NN.Recurrent.Auxiliary
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (tanh)
-- | Type-level specification of a single LSTM layer.
--
-- The only constructor, 'LSTMLayerSpec', carries no fields: input size,
-- hidden size, directionality, dtype and device are all phantom type
-- parameters. The 'A.Randomizable' instances in this module turn such a
-- spec into a freshly initialised 'LSTMLayer'.
data LSTMLayerSpec
       (inputSize :: Nat)
       (hiddenSize :: Nat)
       (directionality :: RNNDirectionality)
       (dtype :: D.DType)
       (device :: (D.DeviceType, Nat))
  = LSTMLayerSpec
  deriving (Eq, Show)
-- | A single LSTM layer together with its parameters.
--
-- The @directionality@ index selects the constructor. A unidirectional layer
-- holds one set of four parameters; a bidirectional layer holds two such sets.
-- Within a set the order is: W_i, W_h, b_i, b_h (going by the shape names
-- @LSTMWIShape@, @LSTMWHShape@, @LSTMBIShape@, @LSTMBHShape@, which are
-- defined elsewhere).
data
  LSTMLayer
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  LSTMUnidirectionalLayer
    :: Parameter device dtype (LSTMWIShape hiddenSize inputSize)
    -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
    -> Parameter device dtype (LSTMBIShape hiddenSize inputSize)
    -> Parameter device dtype (LSTMBHShape hiddenSize inputSize)
    -> LSTMLayer inputSize hiddenSize 'Unidirectional dtype device
  LSTMBidirectionalLayer
    -- first parameter set (presumably the forward direction -- confirm)
    :: Parameter device dtype (LSTMWIShape hiddenSize inputSize)
    -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
    -> Parameter device dtype (LSTMBIShape hiddenSize inputSize)
    -> Parameter device dtype (LSTMBHShape hiddenSize inputSize)
    -- second parameter set (presumably the backward direction -- confirm)
    -> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
    -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
    -> Parameter device dtype (LSTMBIShape hiddenSize inputSize)
    -> Parameter device dtype (LSTMBHShape hiddenSize inputSize)
    -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device

deriving instance Show (LSTMLayer inputSize hiddenSize directionality dtype device)
-- | Typed parameter view of a unidirectional layer: the four parameters, as an
-- 'HList', in the same order as the constructor fields.
instance Parameterized (LSTMLayer inputSize hiddenSize 'Unidirectional dtype device) where
  type
    Parameters (LSTMLayer inputSize hiddenSize 'Unidirectional dtype device) =
      '[ Parameter device dtype (LSTMWIShape hiddenSize inputSize),
         Parameter device dtype (LSTMWHShape hiddenSize inputSize),
         Parameter device dtype (LSTMBIShape hiddenSize inputSize),
         Parameter device dtype (LSTMBHShape hiddenSize inputSize)
       ]

  flattenParameters layer = case layer of
    LSTMUnidirectionalLayer wIn wHid bIn bHid ->
      wIn :. wHid :. bIn :. bHid :. HNil

  -- The old layer is ignored: every field comes from the supplied HList.
  replaceParameters _ params = case params of
    wIn :. wHid :. bIn :. bHid :. HNil ->
      LSTMUnidirectionalLayer wIn wHid bIn bHid
-- | Typed parameter view of a bidirectional layer: both four-parameter sets,
-- as an 'HList', in the same order as the constructor fields.
instance Parameterized (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device) where
  type
    Parameters (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device) =
      '[ Parameter device dtype (LSTMWIShape hiddenSize inputSize),
         Parameter device dtype (LSTMWHShape hiddenSize inputSize),
         Parameter device dtype (LSTMBIShape hiddenSize inputSize),
         Parameter device dtype (LSTMBHShape hiddenSize inputSize),
         Parameter device dtype (LSTMWIShape hiddenSize inputSize),
         Parameter device dtype (LSTMWHShape hiddenSize inputSize),
         Parameter device dtype (LSTMBIShape hiddenSize inputSize),
         Parameter device dtype (LSTMBHShape hiddenSize inputSize)
       ]

  flattenParameters layer = case layer of
    LSTMBidirectionalLayer wIn wHid bIn bHid wIn' wHid' bIn' bHid' ->
      wIn :. wHid :. bIn :. bHid
        :. wIn' :. wHid' :. bIn' :. bHid'
        :. HNil

  -- The old layer is ignored: every field comes from the supplied HList.
  replaceParameters _ params = case params of
    wIn :. wHid :. bIn :. bHid :. wIn' :. wHid' :. bIn' :. bHid' :. HNil ->
      LSTMBidirectionalLayer wIn wHid bIn bHid wIn' wHid' bIn' bHid'
-- | Initialise a unidirectional LSTM layer from its (field-less) spec.
--
-- Both weight tensors are drawn with 'xavierUniformLSTM'; both bias tensors
-- start as 'zeros'. Each tensor is wrapped into a 'Parameter' with
-- 'makeIndependent'.
instance
  ( RandDTypeIsValid device dtype,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownDType dtype,
    KnownDevice device
  ) =>
  A.Randomizable
    (LSTMLayerSpec inputSize hiddenSize 'Unidirectional dtype device)
    (LSTMLayer inputSize hiddenSize 'Unidirectional dtype device)
  where
  sample _ =
    LSTMUnidirectionalLayer
      <$> (makeIndependent =<< xavierUniformLSTM)
      <*> (makeIndependent =<< xavierUniformLSTM)
      -- 'zeros' is a plain value, so it goes straight to 'makeIndependent'
      -- (@makeIndependent =<< pure zeros@ is the same by left identity).
      <*> makeIndependent zeros
      <*> makeIndependent zeros
-- | Initialise a bidirectional LSTM layer from its (field-less) spec.
--
-- Both parameter sets are initialised in the same way: the weight tensors are
-- drawn with 'xavierUniformLSTM' and the bias tensors start as 'zeros'. Each
-- tensor is wrapped into a 'Parameter' with 'makeIndependent'.
instance
  ( RandDTypeIsValid device dtype,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownDType dtype,
    KnownDevice device
  ) =>
  A.Randomizable
    (LSTMLayerSpec inputSize hiddenSize 'Bidirectional dtype device)
    (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
  where
  sample _ =
    LSTMBidirectionalLayer
      -- first parameter set
      <$> (makeIndependent =<< xavierUniformLSTM)
      <*> (makeIndependent =<< xavierUniformLSTM)
      -- 'zeros' is a plain value, so it goes straight to 'makeIndependent'
      -- (@makeIndependent =<< pure zeros@ is the same by left identity).
      <*> makeIndependent zeros
      <*> makeIndependent zeros
      -- second parameter set
      <*> (makeIndependent =<< xavierUniformLSTM)
      <*> (makeIndependent =<< xavierUniformLSTM)
      <*> makeIndependent zeros
      <*> makeIndependent zeros
-- | Untyped ('A.Parameterized') view of an LSTM layer.
--
-- 'flattenParameters' erases the type-level shapes with 'untypeParam'.
-- '_replaceParameters' reads the same number of parameters back from the
-- stream, in the same constructor order, and re-tags them with
-- 'UnsafeMkParameter'. The shapes of the incoming tensors are not checked here.
instance A.Parameterized (LSTMLayer inputSize hiddenSize directionality dtype device) where
  flattenParameters (LSTMUnidirectionalLayer wIn wHid bIn bHid) =
    [untypeParam wIn, untypeParam wHid, untypeParam bIn, untypeParam bHid]
  flattenParameters (LSTMBidirectionalLayer wIn wHid bIn bHid wIn' wHid' bIn' bHid') =
    [untypeParam wIn, untypeParam wHid, untypeParam bIn, untypeParam bHid]
      ++ [untypeParam wIn', untypeParam wHid', untypeParam bIn', untypeParam bHid']

  -- Only the constructor matters; the old parameter values are discarded.
  _replaceParameters LSTMUnidirectionalLayer {} =
    LSTMUnidirectionalLayer <$> next <*> next <*> next <*> next
    where
      -- closed binding, so it is generalised over the parameter shape
      next = UnsafeMkParameter <$> A.nextParameter
  _replaceParameters LSTMBidirectionalLayer {} =
    LSTMBidirectionalLayer
      <$> next <*> next <*> next <*> next
      <*> next <*> next <*> next <*> next
    where
      -- closed binding, so it is generalised over the parameter shape
      next = UnsafeMkParameter <$> A.nextParameter
-- | Type-level specification of a stack of @numLayers@ LSTM layers.
--
-- As with 'LSTMLayerSpec', the single constructor has no fields; every
-- hyper-parameter is a phantom type parameter.
data LSTMLayerStackSpec
       (inputSize :: Nat)
       (hiddenSize :: Nat)
       (numLayers :: Nat)
       (directionality :: RNNDirectionality)
       (dtype :: D.DType)
       (device :: (D.DeviceType, Nat))
  = LSTMLayerStackSpec
  deriving (Eq, Show)
-- | A stack of LSTM layers, indexed by its depth @numLayers@ (at least 1).
--
-- 'LSTMLayer1' holds the single layer that reads @inputSize@ features.
-- 'LSTMLayerK' adds one layer to an existing stack, raising @numLayers@ by
-- one. The added layer's input size is
-- @hiddenSize * NumberOfDirections directionality@.
data
  LSTMLayerStack
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  LSTMLayer1
    :: LSTMLayer inputSize hiddenSize directionality dtype device
    -> LSTMLayerStack inputSize hiddenSize 1 directionality dtype device
  LSTMLayerK
    :: LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device
    -> LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device
    -> LSTMLayerStack inputSize hiddenSize (numLayers + 1) directionality dtype device

deriving instance Show (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
-- | Type-level case split behind the typed 'Parameterized' instance of
-- 'LSTMLayerStack'.
--
-- @flag@ is @'False@ for the single-layer case and @'True@ for stacks of
-- two or more layers; the 'Parameterized' instance chooses it as
-- @2 <=? numLayers@. Indexing on this 'Bool' lets the multi-layer instance
-- stay generic in @numLayers@ without its head overlapping the
-- single-layer one.
class
  LSTMLayerStackParameterized
    (flag :: Bool)
    inputSize
    hiddenSize
    numLayers
    directionality
    dtype
    device
  where
  -- | Type-level list of the parameter types held by the stack.
  type LSTMLayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device :: [Type]

  -- | Gather every parameter of the stack into an 'HList'.
  lstmLayerStackFlattenParameters ::
    Proxy flag ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device ->
    HList (LSTMLayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device)

  -- | Rebuild the stack using the parameters in the given 'HList'.
  lstmLayerStackReplaceParameters ::
    Proxy flag ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device ->
    HList (LSTMLayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device) ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device
-- | Single-layer case (@flag ~ 'False@, @numLayers ~ 1@): the stack's
-- parameters are exactly those of its one 'LSTMLayer'.
instance
  Parameterized (LSTMLayer inputSize hiddenSize directionality dtype device) =>
  LSTMLayerStackParameterized 'False inputSize hiddenSize 1 directionality dtype device
  where
  type
    LSTMLayerStackParameters 'False inputSize hiddenSize 1 directionality dtype device =
      Parameters (LSTMLayer inputSize hiddenSize directionality dtype device)

  -- Delegate directly to the wrapped layer.
  lstmLayerStackFlattenParameters _ (LSTMLayer1 lstmLayer) =
    flattenParameters lstmLayer

  -- Swap the wrapped layer's parameters, then re-wrap it.
  lstmLayerStackReplaceParameters _ (LSTMLayer1 lstmLayer) parameters =
    LSTMLayer1 $ replaceParameters lstmLayer parameters
-- | Multi-layer case (@flag ~ 'True@, i.e. @2 <= numLayers@).
--
-- The parameter list is the parameters of the lower @numLayers - 1@ layers
-- followed by those of the top layer. Applied recursively, this lists the
-- parameters from the base ('LSTMLayer1') layer upward.
instance
  ( Parameterized
      ( LSTMLayer
          (hiddenSize * NumberOfDirections directionality)
          hiddenSize
          directionality
          dtype
          device
      ),
    Parameterized (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device),
    HAppendFD
      (Parameters (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device))
      (Parameters (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device))
      ( Parameters (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
          ++ Parameters (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)
      )
  ) =>
  LSTMLayerStackParameterized 'True inputSize hiddenSize numLayers directionality dtype device
  where
  type
    LSTMLayerStackParameters 'True inputSize hiddenSize numLayers directionality dtype device =
      Parameters (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
        ++ Parameters (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)

  lstmLayerStackFlattenParameters _ (LSTMLayerK topLayer lowerStack) =
    let topParameters = flattenParameters topLayer
        -- Spell the sub-stack's type as @numLayers - 1@ so it lines up with
        -- the 'Parameterized' constraint in the instance context.
        lowerParameters =
          flattenParameters
            @(LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
            lowerStack
     in lowerParameters `happendFD` topParameters

  lstmLayerStackReplaceParameters _ (LSTMLayerK topLayer lowerStack) allParameters =
    let -- Split at the same boundary 'lstmLayerStackFlattenParameters' joined at.
        (lowerParameters, topParameters) = hunappendFD allParameters
        topLayer' = replaceParameters topLayer topParameters
        lowerStack' =
          replaceParameters
            @(LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
            lowerStack
            lowerParameters
     in LSTMLayerK topLayer' lowerStack'
-- | Typed parameters of a stack with at least one layer.
--
-- @flag@ is fixed to @2 <=? numLayers@, and both methods delegate to the
-- matching 'LSTMLayerStackParameterized' instance.
instance
  ( 1 <= numLayers,
    (2 <=? numLayers) ~ flag,
    LSTMLayerStackParameterized flag inputSize hiddenSize numLayers directionality dtype device
  ) =>
  Parameterized (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
  where
  type
    Parameters (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device) =
      LSTMLayerStackParameters (2 <=? numLayers) inputSize hiddenSize numLayers directionality dtype device
  flattenParameters = lstmLayerStackFlattenParameters (Proxy :: Proxy flag)
  replaceParameters = lstmLayerStackReplaceParameters (Proxy :: Proxy flag)
-- | Type-level case split for sampling an 'LSTMLayerStack'.
--
-- As with 'LSTMLayerStackParameterized', @flag@ is @'False@ for one layer
-- and @'True@ for two or more; the 'Randomizable' instance for
-- 'LSTMLayerStackSpec' selects it as @2 <=? numLayers@.
class
  LSTMLayerStackRandomizable
    (flag :: Bool)
    inputSize
    hiddenSize
    numLayers
    directionality
    dtype
    device
  where
  -- | Draw a fresh stack for the given spec.
  lstmLayerStackSample ::
    Proxy flag ->
    LSTMLayerStackSpec inputSize hiddenSize numLayers directionality dtype device ->
    IO (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
-- | Single-layer case: sample one 'LSTMLayer' that reads @inputSize@-wide
-- inputs, and wrap it with 'LSTMLayer1'. Both arguments (proxy and spec)
-- are ignored; everything is determined by the type parameters.
instance
  ( A.Randomizable
      (LSTMLayerSpec inputSize hiddenSize directionality dtype device)
      (LSTMLayer inputSize hiddenSize directionality dtype device)
  ) =>
  LSTMLayerStackRandomizable 'False inputSize hiddenSize 1 directionality dtype device
  where
  lstmLayerStackSample _ _ =
    LSTMLayer1 <$> sample (LSTMLayerSpec @inputSize @hiddenSize @directionality @dtype @device)
-- | Multi-layer case: sample a top layer whose input width is
-- @hiddenSize * NumberOfDirections directionality@, then sample the
-- remaining @numLayers - 1@ layers, and join them with 'LSTMLayerK'.
-- The top layer is sampled first, followed by the sub-stack.
instance
  ( 1 <= numLayers,
    A.Randomizable
      (LSTMLayerSpec (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)
      (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device),
    A.Randomizable
      (LSTMLayerStackSpec inputSize hiddenSize (numLayers - 1) directionality dtype device)
      (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
  ) =>
  LSTMLayerStackRandomizable 'True inputSize hiddenSize numLayers directionality dtype device
  where
  lstmLayerStackSample _ _ =
    LSTMLayerK
      <$> sample (LSTMLayerSpec @(hiddenSize * NumberOfDirections directionality) @hiddenSize @directionality @dtype @device)
      <*> ( -- Explicit types pin the sub-stack to @numLayers - 1@ layers.
            sample
              @(LSTMLayerStackSpec inputSize hiddenSize (numLayers - 1) directionality dtype device)
              @(LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
              LSTMLayerStackSpec
          )
-- | Sample an 'LSTMLayerStack' from its spec.
--
-- @flag@ is fixed to @2 <=? numLayers@, and sampling is delegated to the
-- matching 'LSTMLayerStackRandomizable' instance.
instance
  ( 1 <= numLayers,
    (2 <=? numLayers) ~ flag,
    RandDTypeIsValid device dtype,
    KnownDType dtype,
    KnownDevice device,
    LSTMLayerStackRandomizable flag inputSize hiddenSize numLayers directionality dtype device
  ) =>
  Randomizable
    (LSTMLayerStackSpec inputSize hiddenSize numLayers directionality dtype device)
    (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
  where
  sample = lstmLayerStackSample (Proxy :: Proxy flag)
-- | Untyped parameters of the stack, used by the untyped "Torch.NN" API.
--
-- Each 'LSTMLayerK' node lists its own (top) layer's parameters before
-- those of the sub-stack beneath it. '_replaceParameters' visits the
-- pieces in the same order, so flattening and replacing agree.
--
-- NOTE(review): this is the reverse of the typed 'Parameterized' instance,
-- which lists the lower layers first. It is left as-is because callers may
-- depend on the current untyped order; confirm before aligning the two.
instance A.Parameterized (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device) where
  flattenParameters (LSTMLayer1 layer) =
    A.flattenParameters layer
  flattenParameters (LSTMLayerK topLayer lowerStack) =
    A.flattenParameters topLayer
      ++ A.flattenParameters lowerStack

  _replaceParameters (LSTMLayer1 layer) = do
    layer' <- A._replaceParameters layer
    return $ LSTMLayer1 layer'
  _replaceParameters (LSTMLayerK topLayer lowerStack) = do
    -- Same order as 'flattenParameters': top layer first, then the rest.
    topLayer' <- A._replaceParameters topLayer
    lowerStack' <- A._replaceParameters lowerStack
    return $ LSTMLayerK topLayer' lowerStack'
-- | Specification for sampling an 'LSTM'.
--
-- Only the dropout configuration is stored as a value. Layer sizes, the
-- number of layers, directionality, dtype and device come from the type
-- parameters.
newtype
  LSTMSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LSTMSpec DropoutSpec
  deriving (Show, Generic)
-- | A (possibly multi-layer, possibly bidirectional) LSTM: a stack of
-- 'LSTMLayer's plus a 'Dropout' module.
--
-- The constructor requires @1 <= numLayers@, so an 'LSTM' always has at
-- least one layer.
data
  LSTM
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  LSTM ::
    (1 <= numLayers) =>
    { lstm_layer_stack :: LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device,
      lstm_dropout :: Dropout
    } ->
    LSTM inputSize hiddenSize numLayers directionality dtype device

deriving instance Show (LSTM inputSize hiddenSize numLayers directionality dtype device)
-- | Hand-written 'Generic' instance for 'LSTM'.
--
-- The representation is a bare product of the two fields, with no 'M1'
-- metadata wrappers. 'to' has to rebuild a constructor that carries
-- @1 <= numLayers@; that constraint comes from the instance context
-- (presumably the reason this instance is written by hand instead of
-- derived).
instance
  (1 <= numLayers) =>
  Generic (LSTM inputSize hiddenSize numLayers directionality dtype device)
  where
  type
    Rep (LSTM inputSize hiddenSize numLayers directionality dtype device) =
      Rec0 (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 Dropout
  from (LSTM layerStack dropout) = K1 layerStack :*: K1 dropout
  to (K1 layerStack :*: K1 dropout) = LSTM layerStack dropout
-- | Typed parameters of an 'LSTM'.
--
-- No methods are given here, so the class defaults apply (presumably the
-- 'Generic'-based ones, driven by the hand-written 'Rep' of 'LSTM'). The
-- resulting parameter list is the layer stack's parameters appended with
-- those of the 'Dropout' module, via 'HAppendFD'.
instance
  ( 1 <= numLayers,
    Parameterized (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device))
      (Parameters Dropout)
      ( Parameters (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
          ++ Parameters Dropout
      )
  ) =>
  Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device)
-- | Untyped parameters of an 'LSTM'.
--
-- Only the layer stack contributes parameters. The 'Dropout' field is not
-- flattened and is passed through '_replaceParameters' unchanged.
instance A.Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device) where
  flattenParameters (LSTM layerStack _) = A.flattenParameters layerStack
  _replaceParameters (LSTM layerStack dropout) = do
    layerStack' <- A._replaceParameters layerStack
    return $ LSTM layerStack' dropout
-- | Sample a weight matrix of shape @[4 * hiddenSize, featureSize]@.
--
-- A standard-normal tensor of that shape is drawn with 'randn'. It is then
-- passed, together with its shape, to 'xavierUniformFIXME' with gain
-- @5.0 / 3@, and the resulting untyped tensor is re-wrapped with the same
-- type-level shape.
--
-- NOTE(review): the leading @4 * hiddenSize@ presumably stacks the four
-- LSTM gates, and 5/3 equals PyTorch's @calculate_gain('tanh')@; both
-- look intentional but are not established here.
xavierUniformLSTM ::
  forall device dtype hiddenSize featureSize.
  ( KnownDType dtype,
    KnownNat hiddenSize,
    KnownNat featureSize,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  IO (Tensor device dtype '[4 * hiddenSize, featureSize])
xavierUniformLSTM = do
  -- Named 'weight' rather than 'init' to avoid shadowing 'Prelude.init'.
  weight <- randn :: IO (Tensor device dtype '[4 * hiddenSize, featureSize])
  UnsafeMkTensor
    <$> xavierUniformFIXME
      (toDynamic weight)
      (5.0 / 3)
      (shape @device @dtype @'[4 * hiddenSize, featureSize] weight)
-- | Sample an 'LSTM' from its spec.
--
-- The layer stack is sampled from an 'LSTMLayerStackSpec' built purely
-- from the type parameters. The dropout module is sampled from the spec's
-- 'DropoutSpec'. The stack is sampled first, then the dropout.
instance
  ( KnownDType dtype,
    KnownDevice device,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownNat (NumberOfDirections directionality),
    RandDTypeIsValid device dtype,
    A.Randomizable
      (LSTMLayerStackSpec inputSize hiddenSize numLayers directionality dtype device)
      (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device),
    1 <= numLayers
  ) =>
  A.Randomizable
    (LSTMSpec inputSize hiddenSize numLayers directionality dtype device)
    (LSTM inputSize hiddenSize numLayers directionality dtype device)
  where
  sample (LSTMSpec dropoutSpec) =
    LSTM
      <$> A.sample (LSTMLayerStackSpec @inputSize @hiddenSize @numLayers @directionality @dtype @device)
      <*> A.sample dropoutSpec
-- | Specification for an 'LSTM' together with how its initial recurrent
-- state is supplied. The @initialization@ index records which strategy
-- each constructor uses.
data
  LSTMWithInitSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (initialization :: RNNInitialization)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Presumably an all-zeros initial state (going by the name). Indexed as
  -- 'ConstantInitialization', the same as 'LSTMWithConstInitSpec'.
  LSTMWithZerosInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    LSTMSpec inputSize hiddenSize numLayers directionality dtype device ->
    LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device
  -- | A fixed initial state given by two tensors of shape
  -- @[numLayers * NumberOfDirections directionality, hiddenSize]@
  -- (presumably the initial cell and hidden states; confirm the order
  -- against the consumer).
  LSTMWithConstInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    LSTMSpec inputSize hiddenSize numLayers directionality dtype device ->
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device
  -- | Same payload as 'LSTMWithConstInitSpec', but indexed as
  -- 'LearnedInitialization'; the tensors presumably seed learnable initial
  -- states.
  LSTMWithLearnedInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    LSTMSpec inputSize hiddenSize numLayers directionality dtype device ->
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device

deriving instance Show (LSTMWithInitSpec inputSize hiddenSize numLayers directionality initialization dtype device)
-- | An 'LSTM' bundled with the initial cell (@c@) and hidden (@h@) states it
-- is run from.
--
-- With 'ConstantInitialization' the states are plain 'Tensor's. With
-- 'LearnedInitialization' they are trainable 'Parameter's.
data
  LSTMWithInit
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (initialization :: RNNInitialization)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Fixed (non-trainable) initial states.
  LSTMWithConstInit ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    { -- the wrapped LSTM
      lstmWithConstInit_lstm :: LSTM inputSize hiddenSize numLayers directionality dtype device,
      -- initial cell state, one row per layer and direction
      lstmWithConstInit_c :: Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize],
      -- initial hidden state, one row per layer and direction
      lstmWithConstInit_h :: Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
    } ->
    LSTMWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      'ConstantInitialization
      dtype
      device
  -- | Trainable initial states.
  LSTMWithLearnedInit ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    { -- the wrapped LSTM
      lstmWithLearnedInit_lstm :: LSTM inputSize hiddenSize numLayers directionality dtype device,
      -- learned initial cell state, one row per layer and direction
      lstmWithLearnedInit_c :: Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize],
      -- learned initial hidden state, one row per layer and direction
      lstmWithLearnedInit_h :: Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
    } ->
    LSTMWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      'LearnedInitialization
      dtype
      device

deriving instance Show (LSTMWithInit inputSize hiddenSize numLayers directionality initialization dtype device)
-- | Hand-written 'Generic' instance for the constant-initialisation variant.
--
-- Stock deriving rejects GADT constructors that refine a type index, so the
-- representation is spelled out here. It is a right-nested product of the
-- three fields in declaration order (LSTM, then @c@, then @h@), with no
-- constructor or selector metadata.
instance Generic (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device) where
  type
    Rep (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device) =
      Rec0 (LSTM inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 (Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])
        :*: Rec0 (Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])
  from (LSTMWithConstInit lstm c h) = K1 lstm :*: K1 c :*: K1 h
  to (K1 lstm :*: K1 c :*: K1 h) = LSTMWithConstInit lstm c h
-- | Hand-written 'Generic' instance for the learned-initialisation variant.
--
-- The representation is a right-nested product of the wrapped LSTM and the
-- two initial-state 'Parameter's (@c@, then @h@), with no metadata. That
-- field order is the order in which generic traversals visit them.
instance Generic (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device) where
  type
    Rep (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device) =
      Rec0 (LSTM inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 (Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])
        :*: Rec0 (Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])
  from (LSTMWithLearnedInit lstm c h) = K1 lstm :*: K1 c :*: K1 h
  to (K1 lstm :*: K1 c :*: K1 h) = LSTMWithLearnedInit lstm c h
-- | Typed parameters of a constant-initialised LSTM are exactly those of the
-- wrapped 'LSTM'. The initial states are plain tensors and contribute
-- nothing, which is why @'[]@ is appended.
--
-- The instance has no body, so the class's default methods are used
-- (presumably the 'Generic'-based ones).
instance
  ( Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device))
      '[]
      (Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device) ++ '[])
  ) =>
  Parameterized (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)
-- | Typed parameters of a learned-initialised LSTM are the wrapped 'LSTM''s
-- parameters followed by the two initial-state parameters (@c@, then @h@).
--
-- The instance has no body, so the class's default methods are used
-- (presumably the 'Generic'-based ones).
instance
  ( Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device))
      '[ Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize],
         Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
       ]
      ( Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device)
          ++ '[ Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize],
                Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
              ]
      )
  ) =>
  Parameterized (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)
-- | Sample a constant-initialised LSTM.
--
-- In both cases the LSTM weights are sampled from the inner 'LSTMSpec'.
--
-- * 'LSTMWithZerosInitSpec': both initial states are 'zeros'.
-- * 'LSTMWithConstInitSpec': the caller's tensors are used unchanged.
instance
  ( KnownNat hiddenSize,
    KnownNat numLayers,
    KnownNat (NumberOfDirections directionality),
    KnownDType dtype,
    KnownDevice device,
    A.Randomizable
      (LSTMSpec inputSize hiddenSize numLayers directionality dtype device)
      (LSTM inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  A.Randomizable
    (LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)
    (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)
  where
  sample (LSTMWithZerosInitSpec lstmSpec) =
    LSTMWithConstInit
      <$> A.sample lstmSpec
      <*> pure zeros
      <*> pure zeros
  sample (LSTMWithConstInitSpec lstmSpec c h) =
    LSTMWithConstInit
      <$> A.sample lstmSpec
      <*> pure c
      <*> pure h
-- | Sample a learned-initialised LSTM.
--
-- The LSTM weights are sampled from the inner 'LSTMSpec'. The supplied
-- tensors become the starting values of the trainable initial-state
-- parameters via 'makeIndependent'.
instance
  ( KnownNat hiddenSize,
    KnownNat numLayers,
    KnownNat (NumberOfDirections directionality),
    KnownDType dtype,
    KnownDevice device,
    A.Randomizable
      (LSTMSpec inputSize hiddenSize numLayers directionality dtype device)
      (LSTM inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  A.Randomizable
    (LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)
    (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)
  where
  -- `makeIndependent =<< pure x` is just `makeIndependent x` (monad left
  -- identity). The unused `s@` as-pattern is dropped.
  sample (LSTMWithLearnedInitSpec lstmSpec c h) =
    LSTMWithLearnedInit
      <$> A.sample lstmSpec
      <*> makeIndependent c
      <*> makeIndependent h
-- | Untyped ('A.Parameterized') view of 'LSTMWithInit'.
--
-- * Constant initialisation: only the wrapped 'LSTM''s parameters are
--   exposed. '_replaceParameters' carries the @c@ / @h@ tensors through
--   unchanged.
-- * Learned initialisation: the wrapped 'LSTM''s parameters come first,
--   followed by @c@ and then @h@. '_replaceParameters' consumes parameters
--   in that same order, mirroring 'flattenParameters'.
instance A.Parameterized (LSTMWithInit inputSize hiddenSize numLayers directionality initialization dtype device) where
  flattenParameters LSTMWithConstInit {..} =
    A.flattenParameters lstmWithConstInit_lstm
  flattenParameters LSTMWithLearnedInit {..} =
    A.flattenParameters lstmWithLearnedInit_lstm
      ++ fmap untypeParam [lstmWithLearnedInit_c, lstmWithLearnedInit_h]
  _replaceParameters LSTMWithConstInit {..} = do
    lstmWithConstInit_lstm' <- A._replaceParameters lstmWithConstInit_lstm
    -- `..` re-uses the matched lstmWithConstInit_c / lstmWithConstInit_h
    return $
      LSTMWithConstInit
        { lstmWithConstInit_lstm = lstmWithConstInit_lstm',
          ..
        }
  _replaceParameters LSTMWithLearnedInit {..} = do
    lstmWithLearnedInit_lstm' <- A._replaceParameters lstmWithLearnedInit_lstm
    -- must match the order used by flattenParameters: c, then h
    lstmWithLearnedInit_c' <- A.nextParameter
    lstmWithLearnedInit_h' <- A.nextParameter
    return $
      LSTMWithLearnedInit
        { lstmWithLearnedInit_lstm = lstmWithLearnedInit_lstm',
          lstmWithLearnedInit_c = UnsafeMkParameter lstmWithLearnedInit_c',
          lstmWithLearnedInit_h = UnsafeMkParameter lstmWithLearnedInit_h'
        }
lstmForward ::
forall
shapeOrder
batchSize
seqLen
directionality
initialization
numLayers
inputSize
outputSize
hiddenSize
inputShape
outputShape
hxShape
parameters
tensorParameters
dtype
device.
( KnownNat (NumberOfDirections directionality),
KnownNat numLayers,
KnownNat batchSize,
KnownNat hiddenSize,
KnownRNNShapeOrder shapeOrder,
KnownRNNDirectionality directionality,
outputSize ~ (hiddenSize * NumberOfDirections directionality),
inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
hxShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
parameters ~ Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device),
Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
tensorParameters ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
ATen.Castable (HList tensorParameters) [D.ATenTensor],
HMap' ToDependent parameters tensorParameters
) =>
Bool ->
LSTMWithInit
inputSize
hiddenSize
numLayers
directionality
initialization
dtype
device ->
Tensor device dtype inputShape ->
( Tensor device dtype outputShape,
Tensor device dtype hxShape,
Tensor device dtype hxShape
)
lstmForward :: forall (shapeOrder :: RNNShapeOrder) (batchSize :: Natural)
(seqLen :: Natural) (directionality :: RNNDirectionality)
(initialization :: RNNInitialization) (numLayers :: Natural)
(inputSize :: Natural) (outputSize :: Natural)
(hiddenSize :: Natural) (inputShape :: [Natural])
(outputShape :: [Natural]) (hxShape :: [Natural])
(parameters :: [Type]) (tensorParameters :: [Type])
(dtype :: DType) (device :: (DeviceType, Natural)).
(KnownNat (NumberOfDirections directionality), KnownNat numLayers,
KnownNat batchSize, KnownNat hiddenSize,
KnownRNNShapeOrder shapeOrder,
KnownRNNDirectionality directionality,
outputSize ~ (hiddenSize * NumberOfDirections directionality),
inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
hxShape
~ '[numLayers * NumberOfDirections directionality, batchSize,
hiddenSize],
parameters
~ Parameters
(LSTM inputSize hiddenSize numLayers directionality dtype device),
Parameterized
(LSTM inputSize hiddenSize numLayers directionality dtype device),
tensorParameters
~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
Castable (HList tensorParameters) [ATenTensor],
HMap' ToDependent parameters tensorParameters) =>
Bool
-> LSTMWithInit
inputSize
hiddenSize
numLayers
directionality
initialization
dtype
device
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hxShape,
Tensor device dtype hxShape)
lstmForward Bool
dropoutOn (LSTMWithConstInit lstmModel :: LSTM inputSize hiddenSize numLayers directionality dtype device
lstmModel@(LSTM LSTMLayerStack
inputSize hiddenSize numLayers directionality dtype device
_ (Dropout Double
dropoutProb)) Tensor
device
dtype
'[numLayers * NumberOfDirections directionality, hiddenSize]
cc Tensor
device
dtype
'[numLayers * NumberOfDirections directionality, hiddenSize]
hc) Tensor device dtype inputShape
input =
forall {k} (shapeOrder :: RNNShapeOrder)
(directionality :: RNNDirectionality) (numLayers :: Natural)
(seqLen :: Natural) (batchSize :: Natural) (inputSize :: Natural)
(outputSize :: Natural) (hiddenSize :: Natural)
(inputShape :: [Natural]) (outputShape :: [Natural])
(hxShape :: [Natural]) (tensorParameters :: [k]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownNat numLayers, KnownRNNShapeOrder shapeOrder,
KnownRNNDirectionality directionality,
outputSize ~ (hiddenSize * NumberOfDirections directionality),
inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
hxShape
~ '[numLayers * NumberOfDirections directionality, batchSize,
hiddenSize],
tensorParameters
~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
Castable (HList tensorParameters) [ATenTensor]) =>
HList tensorParameters
-> Double
-> Bool
-> (Tensor device dtype hxShape, Tensor device dtype hxShape)
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hxShape,
Tensor device dtype hxShape)
lstm
@shapeOrder
@directionality
@numLayers
@seqLen
@batchSize
@inputSize
@outputSize
@hiddenSize
@inputShape
@outputShape
@hxShape
@tensorParameters
@dtype
@device
(forall k f (xs :: [k]) (ys :: [k]).
HMap' f xs ys =>
f -> HList xs -> HList ys
hmap' ToDependent
ToDependent forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall f. Parameterized f => f -> HList (Parameters f)
flattenParameters forall a b. (a -> b) -> a -> b
$ LSTM inputSize hiddenSize numLayers directionality dtype device
lstmModel)
Double
dropoutProb
Bool
dropoutOn
(Tensor device dtype hxShape
cc', Tensor device dtype hxShape
hc')
Tensor device dtype inputShape
input
where
cc' :: Tensor device dtype hxShape
cc' =
forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @hxShape
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownShape shape', shape' ~ Broadcast shape shape') =>
Bool -> Tensor device dtype shape -> Tensor device dtype shape'
expand
@'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
Bool
False
forall a b. (a -> b) -> a -> b
$ Tensor
device
dtype
'[numLayers * NumberOfDirections directionality, hiddenSize]
cc
hc' :: Tensor device dtype hxShape
hc' =
forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @hxShape
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownShape shape', shape' ~ Broadcast shape shape') =>
Bool -> Tensor device dtype shape -> Tensor device dtype shape'
expand
@'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
Bool
False
forall a b. (a -> b) -> a -> b
$ Tensor
device
dtype
'[numLayers * NumberOfDirections directionality, hiddenSize]
hc
lstmForward Bool
dropoutOn (LSTMWithLearnedInit lstmModel :: LSTM inputSize hiddenSize numLayers directionality dtype device
lstmModel@(LSTM LSTMLayerStack
inputSize hiddenSize numLayers directionality dtype device
_ (Dropout Double
dropoutProb)) Parameter
device
dtype
'[numLayers * NumberOfDirections directionality, hiddenSize]
cc Parameter
device
dtype
'[numLayers * NumberOfDirections directionality, hiddenSize]
hc) Tensor device dtype inputShape
input =
forall {k} (shapeOrder :: RNNShapeOrder)
(directionality :: RNNDirectionality) (numLayers :: Natural)
(seqLen :: Natural) (batchSize :: Natural) (inputSize :: Natural)
(outputSize :: Natural) (hiddenSize :: Natural)
(inputShape :: [Natural]) (outputShape :: [Natural])
(hxShape :: [Natural]) (tensorParameters :: [k]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownNat numLayers, KnownRNNShapeOrder shapeOrder,
KnownRNNDirectionality directionality,
outputSize ~ (hiddenSize * NumberOfDirections directionality),
inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
hxShape
~ '[numLayers * NumberOfDirections directionality, batchSize,
hiddenSize],
tensorParameters
~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
Castable (HList tensorParameters) [ATenTensor]) =>
HList tensorParameters
-> Double
-> Bool
-> (Tensor device dtype hxShape, Tensor device dtype hxShape)
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hxShape,
Tensor device dtype hxShape)
lstm
@shapeOrder
@directionality
@numLayers
@seqLen
@batchSize
@inputSize
@outputSize
@hiddenSize
@inputShape
@outputShape
@hxShape
@tensorParameters
@dtype
@device
(forall k f (xs :: [k]) (ys :: [k]).
HMap' f xs ys =>
f -> HList xs -> HList ys
hmap' ToDependent
ToDependent forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall f. Parameterized f => f -> HList (Parameters f)
flattenParameters forall a b. (a -> b) -> a -> b
$ LSTM inputSize hiddenSize numLayers directionality dtype device
lstmModel)
Double
dropoutProb
Bool
dropoutOn
(Tensor device dtype hxShape
cc', Tensor device dtype hxShape
hc')
Tensor device dtype inputShape
input
where
cc' :: Tensor device dtype hxShape
cc' =
forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @hxShape
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownShape shape', shape' ~ Broadcast shape shape') =>
Bool -> Tensor device dtype shape -> Tensor device dtype shape'
expand
@'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
Bool
False
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
Parameter device dtype shape -> Tensor device dtype shape
toDependent
forall a b. (a -> b) -> a -> b
$ Parameter
device
dtype
'[numLayers * NumberOfDirections directionality, hiddenSize]
cc
hc' :: Tensor device dtype hxShape
hc' =
forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @hxShape
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
(KnownShape shape', shape' ~ Broadcast shape shape') =>
Bool -> Tensor device dtype shape -> Tensor device dtype shape'
expand
@'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
Bool
False
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape :: [Natural]) (dtype :: DType)
(device :: (DeviceType, Natural)).
Parameter device dtype shape -> Tensor device dtype shape
toDependent
forall a b. (a -> b) -> a -> b
$ Parameter
device
dtype
'[numLayers * NumberOfDirections directionality, hiddenSize]
hc
-- | Convenience wrappers around 'lstmForward' that fix its leading 'Bool'
-- argument (the @dropoutOn@ flag, which 'lstmForward' passes on to 'lstm'):
--
-- * 'lstmForwardWithDropout' passes 'True'.
-- * 'lstmForwardWithoutDropout' passes 'False'.
--
-- The two wrappers differ only in that flag. The dropout probability itself
-- is always the one stored in the model's 'Dropout' layer.
--
-- All type arguments are forwarded explicitly with type applications. Some of
-- them (e.g. @shapeOrder@, @seqLen@) do not occur directly in the argument or
-- result types. They only occur through the @inputShape@/@outputShape@
-- equality constraints, so they are ambiguous without the applications.
lstmForwardWithDropout,
  lstmForwardWithoutDropout ::
    forall
      shapeOrder
      batchSize
      seqLen
      directionality
      initialization
      numLayers
      inputSize
      outputSize
      hiddenSize
      inputShape
      outputShape
      hxShape
      parameters
      tensorParameters
      dtype
      device.
    ( KnownNat (NumberOfDirections directionality),
      KnownNat numLayers,
      KnownNat batchSize,
      KnownNat hiddenSize,
      KnownRNNShapeOrder shapeOrder,
      KnownRNNDirectionality directionality,
      outputSize ~ (hiddenSize * NumberOfDirections directionality),
      inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
      outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
      hxShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
      parameters ~ Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device),
      Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
      tensorParameters ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
      ATen.Castable (HList tensorParameters) [D.ATenTensor],
      HMap' ToDependent parameters tensorParameters
    ) =>
    LSTMWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      initialization
      dtype
      device ->
    Tensor device dtype inputShape ->
    ( Tensor device dtype outputShape,
      Tensor device dtype hxShape,
      Tensor device dtype hxShape
    )
lstmForwardWithDropout =
  lstmForward
    @shapeOrder
    @batchSize
    @seqLen
    @directionality
    @initialization
    @numLayers
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @parameters
    @tensorParameters
    @dtype
    @device
    True -- dropoutOn: apply the model's dropout layer
lstmForwardWithoutDropout =
  lstmForward
    @shapeOrder
    @batchSize
    @seqLen
    @directionality
    @initialization
    @numLayers
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @parameters
    @tensorParameters
    @dtype
    @device
    False -- dropoutOn: disable dropout (e.g. for evaluation)