{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}

module Torch.Typed.NN.Recurrent.LSTM where

import Data.Kind
import Data.Proxy (Proxy (..))
import Foreign.ForeignPtr
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import System.Environment
import System.IO.Unsafe
import qualified Torch.Autograd as A
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import qualified Torch.NN as A
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (sqrt)
import Torch.Typed.NN.Dropout
import Torch.Typed.NN.Recurrent.Auxiliary
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (tanh)

-- | Specification for a single LSTM layer.
--
-- The constructor carries no runtime fields: input size, hidden size,
-- directionality, dtype and device are all recorded in the type parameters.
-- It is the spec argument of the 'A.Randomizable' instances that sample an
-- 'LSTMLayer'.
data
  LSTMLayerSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LSTMLayerSpec
  deriving (Show, Eq)

-- | Parameters of a single LSTM layer, indexed by its directionality.
--
-- The GADT result types tie each constructor to one 'RNNDirectionality', so
-- pattern matching on a constructor reveals whether the layer is
-- unidirectional or bidirectional.
data
  LSTMLayer
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | One set of four parameters, in the order WI, WH, BI, BH.
  -- NOTE(review): presumably input-to-hidden weights, hidden-to-hidden
  -- weights and their two biases; the shape families are defined outside
  -- this chunk, so confirm there.
  LSTMUnidirectionalLayer ::
    Parameter device dtype (LSTMWIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMWHShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBHShape hiddenSize inputSize) ->
    LSTMLayer inputSize hiddenSize 'Unidirectional dtype device
  -- | Two sets of four parameters: WI, WH, BI, BH followed by a second
  -- WI, WH, BI, BH.
  -- NOTE(review): presumably one set per direction; which set is the
  -- forward one is not visible here, so confirm against the forward pass.
  LSTMBidirectionalLayer ::
    Parameter device dtype (LSTMWIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMWHShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBHShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMWIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMWHShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBHShape hiddenSize inputSize) ->
    LSTMLayer inputSize hiddenSize 'Bidirectional dtype device

-- A standalone deriving declaration is used because the constructors above
-- are GADT-style with refined result types, for which an ordinary
-- @deriving@ clause is not accepted.
deriving instance Show (LSTMLayer inputSize hiddenSize directionality dtype device)

-- | Typed parameter list of a unidirectional LSTM layer: the four
-- parameters in constructor order (WI, WH, BI, BH).
instance Parameterized (LSTMLayer inputSize hiddenSize 'Unidirectional dtype device) where
  type
    Parameters (LSTMLayer inputSize hiddenSize 'Unidirectional dtype device) =
      '[ Parameter device dtype (LSTMWIShape hiddenSize inputSize),
         Parameter device dtype (LSTMWHShape hiddenSize inputSize),
         Parameter device dtype (LSTMBIShape hiddenSize inputSize),
         Parameter device dtype (LSTMBHShape hiddenSize inputSize)
       ]

  -- Collect the constructor fields into an 'HList', preserving their order.
  flattenParameters (LSTMUnidirectionalLayer wi wh bi bh) =
    wi :. wh :. bi :. bh :. HNil

  -- Rebuild the layer from the supplied parameters; the previous layer value
  -- is ignored because every field is replaced.
  replaceParameters _ (wi :. wh :. bi :. bh :. HNil) =
    LSTMUnidirectionalLayer wi wh bi bh

-- | Typed parameter list of a bidirectional LSTM layer: the eight
-- parameters in constructor order (WI, WH, BI, BH for the first set, then
-- WI, WH, BI, BH for the second set).
instance Parameterized (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device) where
  type
    Parameters (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device) =
      '[ Parameter device dtype (LSTMWIShape hiddenSize inputSize),
         Parameter device dtype (LSTMWHShape hiddenSize inputSize),
         Parameter device dtype (LSTMBIShape hiddenSize inputSize),
         Parameter device dtype (LSTMBHShape hiddenSize inputSize),
         Parameter device dtype (LSTMWIShape hiddenSize inputSize),
         Parameter device dtype (LSTMWHShape hiddenSize inputSize),
         Parameter device dtype (LSTMBIShape hiddenSize inputSize),
         Parameter device dtype (LSTMBHShape hiddenSize inputSize)
       ]

  -- Collect the constructor fields into an 'HList', preserving their order.
  flattenParameters (LSTMBidirectionalLayer wi wh bi bh wi' wh' bi' bh') =
    wi :. wh :. bi :. bh :. wi' :. wh' :. bi' :. bh' :. HNil

  -- Rebuild the layer from the supplied parameters; the previous layer value
  -- is ignored because every field is replaced.
  replaceParameters _ (wi :. wh :. bi :. bh :. wi' :. wh' :. bi' :. bh' :. HNil) =
    LSTMBidirectionalLayer wi wh bi bh wi' wh' bi' bh'

-- | Randomly initialise a unidirectional LSTM layer.
--
-- Both weight parameters are drawn with 'xavierUniformLSTM'; both bias
-- parameters start as all-zero tensors. Every tensor is wrapped with
-- 'makeIndependent' to obtain a 'Parameter'.
instance
  ( RandDTypeIsValid device dtype,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownDType dtype,
    KnownDevice device
  ) =>
  A.Randomizable
    (LSTMLayerSpec inputSize hiddenSize 'Unidirectional dtype device)
    (LSTMLayer inputSize hiddenSize 'Unidirectional dtype device)
  where
  -- The spec has no fields, so it is only used to select the instance.
  -- Argument order follows the constructor: WI, WH, BI, BH.
  sample _ =
    LSTMUnidirectionalLayer
      <$> (makeIndependent =<< xavierUniformLSTM)
      <*> (makeIndependent =<< xavierUniformLSTM)
      <*> makeIndependent zeros
      <*> makeIndependent zeros

-- | Randomly initialise a bidirectional LSTM layer.
--
-- Each of the two parameter sets gets weights drawn with
-- 'xavierUniformLSTM' and all-zero biases. Every tensor is wrapped with
-- 'makeIndependent' to obtain a 'Parameter'.
instance
  ( RandDTypeIsValid device dtype,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownDType dtype,
    KnownDevice device
  ) =>
  A.Randomizable
    (LSTMLayerSpec inputSize hiddenSize 'Bidirectional dtype device)
    (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
  where
  -- The spec has no fields, so it is only used to select the instance.
  -- Argument order follows the constructor: WI, WH, BI, BH, then the second
  -- WI, WH, BI, BH.
  sample _ =
    LSTMBidirectionalLayer
      <$> (makeIndependent =<< xavierUniformLSTM)
      <*> (makeIndependent =<< xavierUniformLSTM)
      <*> makeIndependent zeros
      <*> makeIndependent zeros
      <*> (makeIndependent =<< xavierUniformLSTM)
      <*> (makeIndependent =<< xavierUniformLSTM)
      <*> makeIndependent zeros
      <*> makeIndependent zeros

-- | Untyped parameter interface for an LSTM layer of either directionality.
--
-- 'flattenParameters' lists the parameters in constructor order with their
-- type indices erased; '_replaceParameters' consumes the same number of
-- parameters from the stream, in the same order, and rebuilds the layer.
instance A.Parameterized (LSTMLayer inputSize hiddenSize directionality dtype device) where
  flattenParameters (LSTMUnidirectionalLayer wi wh bi bh) =
    [ untypeParam wi,
      untypeParam wh,
      untypeParam bi,
      untypeParam bh
    ]
  flattenParameters (LSTMBidirectionalLayer wi wh bi bh wi' wh' bi' bh') =
    [ untypeParam wi,
      untypeParam wh,
      untypeParam bi,
      untypeParam bh,
      untypeParam wi',
      untypeParam wh',
      untypeParam bi',
      untypeParam bh'
    ]

  -- The existing fields are only matched to pick the constructor; every
  -- field is replaced by the next parameter from the stream. The applicative
  -- chain pulls parameters left to right, i.e. in constructor order.
  -- 'UnsafeMkParameter' re-attaches the device/dtype/shape indices without
  -- checking them against the underlying untyped parameter.
  _replaceParameters (LSTMUnidirectionalLayer _ _ _ _) =
    LSTMUnidirectionalLayer
      <$> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)
  _replaceParameters (LSTMBidirectionalLayer _ _ _ _ _ _ _ _) =
    LSTMBidirectionalLayer
      <$> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)
      <*> (UnsafeMkParameter <$> A.nextParameter)

-- | Specification for a stack of LSTM layers.
--
-- The single constructor carries no fields: everything that describes the
-- stack (input size, hidden size, number of layers, directionality, dtype and
-- device) lives in the type parameters. A value of this type is what the
-- 'Randomizable' instance in this module takes to sample an 'LSTMLayerStack'
-- with the same type indices.
data
  LSTMLayerStackSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LSTMLayerStackSpec
  deriving (Show, Eq)

-- | Input-to-hidden, hidden-to-hidden, and bias parameters for a multilayered
-- (and optionally bidirectional) LSTM.
--
-- The stack is defined inductively; the @numLayers@ index records how many
-- layers it contains: 'LSTMLayer1' produces index @1@ and every 'LSTMLayerK'
-- increments the index of the stack it wraps by one.
data
  LSTMLayerStack
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- A stack made of a single layer whose input-size index is @inputSize@.
  LSTMLayer1 ::
    LSTMLayer inputSize hiddenSize directionality dtype device ->
    LSTMLayerStack inputSize hiddenSize 1 directionality dtype device
  -- One more layer combined with an existing stack of @numLayers@ layers.
  -- The new layer's input-size index is
  -- @hiddenSize * NumberOfDirections directionality@.
  -- NOTE(review): presumably this is the layer applied after the wrapped
  -- stack (its input width matches the per-step output of the layers below);
  -- confirm against the forward pass.
  LSTMLayerK ::
    LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device ->
    LSTMLayerStack inputSize hiddenSize (numLayers + 1) directionality dtype device

-- Standalone deriving: 'LSTMLayerStack' is declared in GADT syntax, so the
-- 'Show' instance is requested separately from the declaration.
deriving instance Show (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)

-- | Helper class used to implement the typed 'Parameterized' instance of
-- 'LSTMLayerStack'.
--
-- The extra @flag@ index selects the instance. This module provides a
-- @'False@ instance for a one-layer stack and a @'True@ instance for the
-- inductive case; the 'Parameterized' instance for 'LSTMLayerStack' supplies
-- @flag@ as @2 <=? numLayers@.
class LSTMLayerStackParameterized (flag :: Bool) inputSize hiddenSize numLayers directionality dtype device where
  -- | Type-level list of the parameter types of the stack, in the order in
  -- which 'lstmLayerStackFlattenParameters' returns them.
  type LSTMLayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device :: [Type]
  -- | Collect all parameters of the stack into a heterogeneous list.
  -- The 'Proxy' argument only fixes @flag@.
  lstmLayerStackFlattenParameters ::
    Proxy flag ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device ->
    HList (LSTMLayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device)
  -- | Rebuild a stack from an existing stack and a heterogeneous list of
  -- replacement parameters laid out as by 'lstmLayerStackFlattenParameters'.
  lstmLayerStackReplaceParameters ::
    Proxy flag ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device ->
    HList (LSTMLayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device) ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device

-- | Base case (@flag ~ 'False@, @numLayers ~ 1@): the parameters of a
-- one-layer stack are exactly the parameters of its single layer.
instance
  Parameterized (LSTMLayer inputSize hiddenSize directionality dtype device) =>
  LSTMLayerStackParameterized 'False inputSize hiddenSize 1 directionality dtype device
  where
  type
    LSTMLayerStackParameters 'False inputSize hiddenSize 1 directionality dtype device =
      Parameters (LSTMLayer inputSize hiddenSize directionality dtype device)

  -- Delegate to the typed 'Parameterized' instance of the wrapped layer.
  lstmLayerStackFlattenParameters _ (LSTMLayer1 lstmLayer) = flattenParameters lstmLayer

  -- Replace the wrapped layer's parameters and re-wrap it.
  lstmLayerStackReplaceParameters _ (LSTMLayer1 lstmLayer) parameters =
    LSTMLayer1 $ replaceParameters lstmLayer parameters

-- | Inductive case (@flag ~ 'True@): the parameter list of a stack built with
-- 'LSTMLayerK' is the parameter list of the wrapped @numLayers - 1@ stack
-- followed by the parameter list of the additional layer.
instance
  ( Parameterized
      ( LSTMLayer
          (hiddenSize * NumberOfDirections directionality)
          hiddenSize
          directionality
          dtype
          device
      ),
    Parameterized (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device),
    HAppendFD
      (Parameters (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device))
      (Parameters (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device))
      ( Parameters (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
          ++ Parameters (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)
      )
  ) =>
  LSTMLayerStackParameterized 'True inputSize hiddenSize numLayers directionality dtype device
  where
  type
    LSTMLayerStackParameters 'True inputSize hiddenSize numLayers directionality dtype device =
      Parameters (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
        ++ Parameters (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)

  -- Flatten the wrapped stack and the additional layer separately, then
  -- append: wrapped-stack parameters first, additional-layer parameters last.
  -- The type application pins the wrapped stack's depth to @numLayers - 1@ so
  -- that the instance context above applies.
  lstmLayerStackFlattenParameters _ (LSTMLayerK lstmLayer lstmLayerStack) =
    let parameters = flattenParameters lstmLayer
        parameters' =
          flattenParameters
            @(LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
            lstmLayerStack
     in parameters' `happendFD` parameters

  -- Inverse of the above: split the incoming list at the same boundary and
  -- hand each half to the component it belongs to.
  lstmLayerStackReplaceParameters _ (LSTMLayerK lstmLayer lstmLayerStack) parameters'' =
    let (parameters', parameters) = hunappendFD parameters''
        lstmLayer' = replaceParameters lstmLayer parameters
        lstmLayerStack' =
          replaceParameters
            @(LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
            lstmLayerStack
            parameters'
     in LSTMLayerK lstmLayer' lstmLayerStack'

-- | Typed 'Parameterized' instance for a layer stack.
--
-- The work is delegated to 'LSTMLayerStackParameterized'; the instance to use
-- is selected by @flag@, which the context fixes to @2 <=? numLayers@.
instance
  ( 1 <= numLayers,
    (2 <=? numLayers) ~ flag,
    LSTMLayerStackParameterized flag inputSize hiddenSize numLayers directionality dtype device
  ) =>
  Parameterized (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
  where
  type
    Parameters (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device) =
      LSTMLayerStackParameters (2 <=? numLayers) inputSize hiddenSize numLayers directionality dtype device

  -- The proxy only communicates the choice of @flag@ to the helper class.
  flattenParameters = lstmLayerStackFlattenParameters (Proxy :: Proxy flag)
  replaceParameters = lstmLayerStackReplaceParameters (Proxy :: Proxy flag)

-- | Helper class used to implement the 'Randomizable' instance of
-- 'LSTMLayerStack'.
--
-- As with 'LSTMLayerStackParameterized', the @flag@ index selects the
-- instance: this module provides a @'False@ instance for a one-layer stack
-- and a @'True@ instance for the inductive case.
class LSTMLayerStackRandomizable (flag :: Bool) inputSize hiddenSize numLayers directionality dtype device where
  -- | Sample a layer stack for the given specification in 'IO'.
  -- The 'Proxy' argument only fixes @flag@.
  lstmLayerStackSample ::
    Proxy flag ->
    LSTMLayerStackSpec inputSize hiddenSize numLayers directionality dtype device ->
    IO (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)

-- | Base case (@flag ~ 'False@, @numLayers ~ 1@): sample a single layer and
-- wrap it with 'LSTMLayer1'.
instance
  ( A.Randomizable
      (LSTMLayerSpec inputSize hiddenSize directionality dtype device)
      (LSTMLayer inputSize hiddenSize directionality dtype device)
  ) =>
  LSTMLayerStackRandomizable 'False inputSize hiddenSize 1 directionality dtype device
  where
  -- Both arguments are ignored: the layer specification is rebuilt from the
  -- type indices alone via explicit type applications.
  lstmLayerStackSample _ _ =
    LSTMLayer1
      <$> (sample $ LSTMLayerSpec @inputSize @hiddenSize @directionality @dtype @device)

-- | Inductive case (@flag ~ 'True@): sample one additional layer, then a
-- stack of @numLayers - 1@ layers, and combine them with 'LSTMLayerK'.
instance
  ( 1 <= numLayers,
    A.Randomizable
      (LSTMLayerSpec (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)
      (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device),
    A.Randomizable
      (LSTMLayerStackSpec inputSize hiddenSize (numLayers - 1) directionality dtype device)
      (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
  ) =>
  LSTMLayerStackRandomizable 'True inputSize hiddenSize numLayers directionality dtype device
  where
  -- Both arguments are ignored; the specifications of the two components are
  -- determined by the type indices. The additional layer is sampled before
  -- the wrapped stack (left-to-right applicative order in 'IO'). The type
  -- applications on the second 'sample' pin the wrapped stack's depth to
  -- @numLayers - 1@ so that the instance context above applies.
  lstmLayerStackSample _ _ =
    LSTMLayerK
      <$> (sample $ LSTMLayerSpec @(hiddenSize * NumberOfDirections directionality) @hiddenSize @directionality @dtype @device)
      <*> ( sample
              @(LSTMLayerStackSpec inputSize hiddenSize (numLayers - 1) directionality dtype device)
              @(LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
              $ LSTMLayerStackSpec
          )

-- | 'Randomizable' instance for a layer stack.
--
-- Sampling is delegated to 'LSTMLayerStackRandomizable'; the instance to use
-- is selected by @flag@, which the context fixes to @2 <=? numLayers@.
instance
  ( 1 <= numLayers,
    (2 <=? numLayers) ~ flag,
    RandDTypeIsValid device dtype,
    KnownDType dtype,
    KnownDevice device,
    LSTMLayerStackRandomizable flag inputSize hiddenSize numLayers directionality dtype device
  ) =>
  Randomizable
    (LSTMLayerStackSpec inputSize hiddenSize numLayers directionality dtype device)
    (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
  where
  -- The proxy only communicates the choice of @flag@ to the helper class.
  sample = lstmLayerStackSample (Proxy :: Proxy flag)

-- | Untyped ('Torch.NN') parameter traversal for a layer stack.
--
-- For a stack built with 'LSTMLayerK' the additional layer is visited first
-- and the wrapped stack second, both when flattening and when replacing, so
-- the two methods agree with each other.
--
-- NOTE(review): this order is the reverse of the typed
-- 'LSTMLayerStackParameterized' @'True@ instance, which lists the wrapped
-- stack's parameters before the additional layer's. Confirm whether the two
-- orders are meant to differ before relying on either.
instance A.Parameterized (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device) where
  flattenParameters (LSTMLayer1 layer) =
    A.flattenParameters layer
  flattenParameters (LSTMLayerK layer stack) =
    A.flattenParameters layer
      ++ A.flattenParameters stack

  _replaceParameters (LSTMLayer1 layer) = do
    layer' <- A._replaceParameters layer
    pure $ LSTMLayer1 layer'
  _replaceParameters (LSTMLayerK layer stack) = do
    -- Replace in the same order in which 'flattenParameters' emits.
    layer' <- A._replaceParameters layer
    stack' <- A._replaceParameters stack
    pure $ LSTMLayerK layer' stack'

-- | Specification for an 'LSTM'.
--
-- Wraps the 'DropoutSpec' for the model; every other property (input size,
-- hidden size, number of layers, directionality, dtype and device) is carried
-- by the type parameters.
newtype
  LSTMSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LSTMSpec DropoutSpec
  deriving (Show, Generic)

-- | A multi-layer LSTM: a stack of LSTM layers together with a 'Dropout'.
--
-- The constructor carries the constraint @1 <= numLayers@, so pattern
-- matching on an 'LSTM' value brings that fact into scope.
data
  LSTM
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  LSTM ::
    (1 <= numLayers) =>
    { -- The stacked layers holding the trainable parameters.
      lstm_layer_stack :: LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device,
      -- The dropout component stored alongside the layers.
      lstm_dropout :: Dropout
    } ->
    LSTM inputSize hiddenSize numLayers directionality dtype device

-- Standalone deriving: 'LSTM' is declared in GADT syntax, so the 'Show'
-- instance is requested separately from the declaration.
deriving instance Show (LSTM inputSize hiddenSize numLayers directionality dtype device)

-- | Hand-written 'Generic' instance for 'LSTM'.
--
-- The representation is a plain product of the two fields, without datatype,
-- constructor or selector metadata. The @1 <= numLayers@ context is needed by
-- 'to', which has to apply the constrained 'LSTM' constructor.
-- NOTE(review): presumably written by hand because the constructor's context
-- prevents a derived instance; confirm.
instance
  (1 <= numLayers) =>
  Generic (LSTM inputSize hiddenSize numLayers directionality dtype device)
  where
  type
    Rep (LSTM inputSize hiddenSize numLayers directionality dtype device) =
      Rec0 (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 Dropout

  -- 'RecordWildCards' binds @lstm_layer_stack@ and @lstm_dropout@ here.
  from (LSTM {..}) = K1 lstm_layer_stack :*: K1 lstm_dropout

  to (K1 layerStack :*: K1 dropout) = LSTM layerStack dropout

-- | Typed 'Parameterized' instance for 'LSTM'.
--
-- No methods are defined here, so the class defaults are used.
-- NOTE(review): presumably those defaults are derived from the 'Generic'
-- instance of 'LSTM' — confirm against "Torch.Typed.Parameter".
-- The 'HAppendFD' constraint relates the parameter list of the layer stack
-- and that of 'Dropout' to their concatenation.
instance
  ( 1 <= numLayers,
    Parameterized (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device))
      (Parameters Dropout)
      ( Parameters (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
          ++ Parameters Dropout
      )
  ) =>
  Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device)

-- TODO: when we have canonical initializers do this correctly:
-- https://github.com/pytorch/pytorch/issues/9221
-- https://discuss.pytorch.org/t/initializing-rnn-gru-and-lstm-correctly/23605

-- | Untyped ('A.Parameterized') view of an 'LSTM'.
--
-- Only the layer stack contributes parameters; the dropout field is neither
-- flattened nor replaced and is carried over unchanged.
instance A.Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device) where
  -- The flattened parameters are exactly those of the layer stack.
  flattenParameters LSTM {..} = A.flattenParameters lstm_layer_stack

  -- Replace the parameters of the layer stack and rebuild the record.
  -- Matching on the 'LSTM' constructor brings @1 <= numLayers@ into scope,
  -- which is needed to construct the result.
  _replaceParameters LSTM {..} = do
    lstm_layer_stack' <- A._replaceParameters lstm_layer_stack
    return $
      LSTM
        { lstm_layer_stack = lstm_layer_stack',
          -- The wildcard fills in 'lstm_dropout' from the matched record.
          ..
        }

-- | Helper to do xavier uniform initializations on weight matrices and
-- orthogonal initializations for the gates. (When implemented.)
--
-- Draws a tensor of shape @[4 * hiddenSize, featureSize]@ with 'randn',
-- hands its untyped form, the constant @5.0 / 3@ and its shape to
-- 'xavierUniformFIXME', and re-wraps the untyped result with
-- 'UnsafeMkTensor' (no runtime check that the result matches the type).
--
-- NOTE(review): the @5.0 / 3@ argument is presumably the gain expected by
-- 'xavierUniformFIXME' — confirm against its definition.
xavierUniformLSTM ::
  forall device dtype hiddenSize featureSize.
  ( KnownDType dtype,
    KnownNat hiddenSize,
    KnownNat featureSize,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  IO (Tensor device dtype '[4 * hiddenSize, featureSize])
xavierUniformLSTM = do
  -- Named 'seed' rather than 'init' so that 'Prelude.init' is not shadowed.
  seed <- randn :: IO (Tensor device dtype '[4 * hiddenSize, featureSize])
  UnsafeMkTensor
    <$> xavierUniformFIXME
      (toDynamic seed)
      (5.0 / 3)
      (shape @device @dtype @'[4 * hiddenSize, featureSize] seed)

-- | Sample an 'LSTM' from an 'LSTMSpec'.
--
-- The layer stack is sampled from an 'LSTMLayerStackSpec' at the same type
-- indices, and the dropout layer is sampled from the 'DropoutSpec' stored in
-- the 'LSTMSpec'.
instance
  ( KnownDType dtype,
    KnownDevice device,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownNat (NumberOfDirections directionality),
    RandDTypeIsValid device dtype,
    A.Randomizable
      (LSTMLayerStackSpec inputSize hiddenSize numLayers directionality dtype device)
      (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device),
    1 <= numLayers
  ) =>
  A.Randomizable
    (LSTMSpec inputSize hiddenSize numLayers directionality dtype device)
    (LSTM inputSize hiddenSize numLayers directionality dtype device)
  where
  sample (LSTMSpec dropoutSpec) =
    LSTM
      -- The layer-stack spec carries no data; the type applications pin the
      -- indices to those of the instance head.
      <$> A.sample (LSTMLayerStackSpec @inputSize @hiddenSize @numLayers @directionality @dtype @device)
      <*> A.sample dropoutSpec

-- | A specification for a long short-term memory (LSTM) network together
-- with the initial values of its memory cell and hidden state.
--
-- The @initialization@ index records whether the initial states are
-- constants ('ConstantInitialization') or learned ('LearnedInitialization').
data
  LSTMWithInitSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (initialization :: RNNInitialization)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Weights drawn from Xavier-Uniform
  --   with zeros-value initialized biases and cell states.
  LSTMWithZerosInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    LSTMSpec inputSize hiddenSize numLayers directionality dtype device ->
    LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device
  -- | Weights drawn from Xavier-Uniform
  --   with zeros-value initialized biases
  --   and user-provided cell states.
  LSTMWithConstInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    LSTMSpec inputSize hiddenSize numLayers directionality dtype device ->
    -- | The initial values of the memory cell
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    -- | The initial values of the hidden state
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device
  -- | Weights drawn from Xavier-Uniform
  --   with zeros-value initialized biases
  --   and learned cell states.
  LSTMWithLearnedInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    LSTMSpec inputSize hiddenSize numLayers directionality dtype device ->
    -- | The initial (learnable)
    -- values of the memory cell
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    -- | The initial (learnable)
    -- values of the hidden state
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device

-- | Standalone-derived 'Show' instance covering every @initialization@ index
-- of 'LSTMWithInitSpec'.
deriving instance Show (LSTMWithInitSpec inputSize hiddenSize numLayers directionality initialization dtype device)

-- | A long, short-term memory layer with either fixed initial
-- states for the memory cells and hidden state or learnable
-- initial states for the memory cells and hidden state.
--
-- The @initialization@ index selects the constructor:
-- 'ConstantInitialization' stores the initial states as plain tensors,
-- 'LearnedInitialization' stores them as parameters.
data
  LSTMWithInit
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (initialization :: RNNInitialization)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | An LSTM whose initial states are constant tensors.
  LSTMWithConstInit ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    { -- | The wrapped LSTM.
      lstmWithConstInit_lstm :: LSTM inputSize hiddenSize numLayers directionality dtype device,
      -- | The initial values of the memory cell.
      lstmWithConstInit_c :: Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize],
      -- | The initial values of the hidden state.
      lstmWithConstInit_h :: Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
    } ->
    LSTMWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      'ConstantInitialization
      dtype
      device
  -- | An LSTM whose initial states are learnable parameters.
  LSTMWithLearnedInit ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    { -- | The wrapped LSTM.
      lstmWithLearnedInit_lstm :: LSTM inputSize hiddenSize numLayers directionality dtype device,
      -- | The initial (learnable) values of the memory cell.
      lstmWithLearnedInit_c :: Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize],
      -- | The initial (learnable) values of the hidden state.
      lstmWithLearnedInit_h :: Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
    } ->
    LSTMWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      'LearnedInitialization
      dtype
      device

-- | Standalone-derived 'Show' instance covering every @initialization@ index
-- of 'LSTMWithInit'.
deriving instance Show (LSTMWithInit inputSize hiddenSize numLayers directionality initialization dtype device)

-- | Hand-written 'Generic' instance for the constant-initialization variant
-- of 'LSTMWithInit'.
--
-- The representation is the right-nested product of the wrapped LSTM, the
-- initial memory cell tensor and the initial hidden state tensor, in field
-- declaration order and without metadata wrappers.
instance Generic (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device) where
  type
    Rep (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device) =
      Rec0 (LSTM inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 (Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])
        :*: Rec0 (Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])

  -- Split the record into its three fields.
  from (LSTMWithConstInit {..}) =
    K1 lstmWithConstInit_lstm
      :*: K1 lstmWithConstInit_c
      :*: K1 lstmWithConstInit_h

  -- Rebuild the record from the three-component product.
  to (K1 lstm :*: K1 c :*: K1 h) = LSTMWithConstInit lstm c h

-- | Hand-written 'Generic' instance for the learned-initialization variant
-- of 'LSTMWithInit'.
--
-- The representation is the right-nested product of the wrapped LSTM, the
-- initial memory cell parameter and the initial hidden state parameter, in
-- field declaration order and without metadata wrappers.
instance Generic (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device) where
  type
    Rep (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device) =
      Rec0 (LSTM inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 (Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])
        :*: Rec0 (Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])

  -- Split the record into its three fields.
  from (LSTMWithLearnedInit {..}) =
    K1 lstmWithLearnedInit_lstm
      :*: K1 lstmWithLearnedInit_c
      :*: K1 lstmWithLearnedInit_h

  -- Rebuild the record from the three-component product.
  to (K1 lstm :*: K1 c :*: K1 h) = LSTMWithLearnedInit lstm c h

-- | Typed 'Parameterized' instance for the constant-initialization variant.
--
-- No methods are defined here, so the class defaults are used.
-- The 'HAppendFD' constraint appends the empty list @'[]@ to the parameters
-- of the wrapped 'LSTM'.
-- NOTE(review): this suggests the constant initial-state tensors contribute
-- no parameters — confirm against "Torch.Typed.Parameter".
instance
  ( Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device))
      '[]
      (Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device) ++ '[])
  ) =>
  Parameterized (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)

-- | Typed 'Parameterized' instance for the learned-initialization variant.
--
-- No methods are defined here, so the class defaults are used.
-- The 'HAppendFD' constraint appends two 'Parameter's of shape
-- @[numLayers * NumberOfDirections directionality, hiddenSize]@ to the
-- parameters of the wrapped 'LSTM'; these match the types of the
-- learned memory cell and hidden state fields.
instance
  ( Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device))
      '[ Parameter
           device
           dtype
           '[numLayers * NumberOfDirections directionality, hiddenSize],
         Parameter
           device
           dtype
           '[numLayers * NumberOfDirections directionality, hiddenSize]
       ]
      ( Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device)
          ++ '[ Parameter
                  device
                  dtype
                  '[numLayers * NumberOfDirections directionality, hiddenSize],
                Parameter
                  device
                  dtype
                  '[numLayers * NumberOfDirections directionality, hiddenSize]
              ]
      )
  ) =>
  Parameterized (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)

-- | Sample an 'LSTMWithInit' with constant initial states.
--
-- In both cases the wrapped 'LSTM' is sampled from the contained 'LSTMSpec'.
-- 'LSTMWithZerosInitSpec' uses 'zeros' for the memory cell and the hidden
-- state; 'LSTMWithConstInitSpec' uses the tensors supplied in the spec.
instance
  ( KnownNat hiddenSize,
    KnownNat numLayers,
    KnownNat (NumberOfDirections directionality),
    KnownDType dtype,
    KnownDevice device,
    A.Randomizable
      (LSTMSpec inputSize hiddenSize numLayers directionality dtype device)
      (LSTM inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  A.Randomizable
    (LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)
    (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)
  where
  -- Zero-valued initial memory cell and hidden state.
  sample (LSTMWithZerosInitSpec lstmSpec) =
    LSTMWithConstInit
      <$> A.sample lstmSpec
      <*> pure zeros
      <*> pure zeros
  -- Caller-provided initial memory cell @c@ and hidden state @h@.
  sample (LSTMWithConstInitSpec lstmSpec c h) =
    LSTMWithConstInit
      <$> A.sample lstmSpec
      <*> pure c
      <*> pure h

-- | Sample an 'LSTMWithInit' with learnable initial states.
--
-- The wrapped 'LSTM' is sampled from the contained 'LSTMSpec'; the supplied
-- memory cell and hidden state tensors are each turned into a 'Parameter'
-- with 'makeIndependent'.
instance
  ( KnownNat hiddenSize,
    KnownNat numLayers,
    KnownNat (NumberOfDirections directionality),
    KnownDType dtype,
    KnownDevice device,
    A.Randomizable
      (LSTMSpec inputSize hiddenSize numLayers directionality dtype device)
      (LSTM inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  A.Randomizable
    (LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)
    (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)
  where
  -- Effects run left to right: the LSTM is sampled first, then @c@ and @h@
  -- are made independent, in that order.
  sample (LSTMWithLearnedInitSpec lstmSpec c h) =
    LSTMWithLearnedInit
      <$> A.sample lstmSpec
      <*> makeIndependent c
      <*> makeIndependent h

-- | Untyped parameter access for an 'LSTM' bundled with its initial states.
--
-- * Constant initialization: only the wrapped 'LSTM' contributes parameters;
--   the initial cell and hidden state tensors are carried over unchanged.
-- * Learned initialization: the wrapped 'LSTM' parameters come first,
--   followed by the initial cell state and then the initial hidden state.
--   '_replaceParameters' consumes the stream in exactly that order.
instance A.Parameterized (LSTMWithInit inputSize hiddenSize numLayers directionality initialization dtype device) where
  flattenParameters (LSTMWithConstInit inner _ _) =
    A.flattenParameters inner
  flattenParameters (LSTMWithLearnedInit inner c h) =
    A.flattenParameters inner ++ map untypeParam [c, h]

  _replaceParameters (LSTMWithConstInit inner c h) =
    rebuild <$> A._replaceParameters inner
    where
      -- Only the inner LSTM is refreshed; both constant states are reused.
      rebuild inner' =
        LSTMWithConstInit
          { lstmWithConstInit_lstm = inner',
            lstmWithConstInit_c = c,
            lstmWithConstInit_h = h
          }
  _replaceParameters (LSTMWithLearnedInit inner _ _) =
    -- Effects run left to right: inner LSTM, then cell state, then hidden
    -- state, mirroring the order produced by 'flattenParameters'.
    rebuild
      <$> A._replaceParameters inner
      <*> A.nextParameter
      <*> A.nextParameter
    where
      rebuild inner' c' h' =
        LSTMWithLearnedInit
          { lstmWithLearnedInit_lstm = inner',
            lstmWithLearnedInit_c = UnsafeMkParameter c',
            lstmWithLearnedInit_h = UnsafeMkParameter h'
          }

-- | Run the multi-layer LSTM over @input@, starting from the initial cell and
-- hidden states stored in the model.
--
-- The first argument is the dropout switch that is handed to 'lstm' together
-- with the dropout probability taken from the wrapped 'LSTM'. The result is
-- the triple returned by 'lstm': the output sequence followed by two state
-- tensors of shape @hxShape@.
--
-- The stored initial states have shape
-- @[numLayers * NumberOfDirections directionality, hiddenSize]@. They are
-- broadcast over the batch and then /transposed/ into
-- @[numLayers * NumberOfDirections directionality, batchSize, hiddenSize]@.
-- A plain 'reshape' of the expanded tensor would reinterpret the row-major
-- layout instead of swapping the axes, mixing the rows of different
-- layers/directions whenever both the batch size and the number of
-- layers/directions exceed one.
lstmForward ::
  forall
    shapeOrder
    batchSize
    seqLen
    directionality
    initialization
    numLayers
    inputSize
    outputSize
    hiddenSize
    inputShape
    outputShape
    hxShape
    parameters
    tensorParameters
    dtype
    device.
  ( KnownNat (NumberOfDirections directionality),
    KnownNat numLayers,
    KnownNat batchSize,
    KnownNat hiddenSize,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hxShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
    parameters ~ Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device),
    Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    tensorParameters ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor],
    HMap' ToDependent parameters tensorParameters
  ) =>
  Bool ->
  LSTMWithInit
    inputSize
    hiddenSize
    numLayers
    directionality
    initialization
    dtype
    device ->
  Tensor device dtype inputShape ->
  ( Tensor device dtype outputShape,
    Tensor device dtype hxShape,
    Tensor device dtype hxShape
  )
lstmForward dropoutOn (LSTMWithConstInit lstmModel@(LSTM _ (Dropout dropoutProb)) cc hc) input =
  lstm
    @shapeOrder
    @directionality
    @numLayers
    @seqLen
    @batchSize
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @tensorParameters
    @dtype
    @device
    (hmap' ToDependent . flattenParameters $ lstmModel)
    dropoutProb
    dropoutOn
    (cc', hc')
    input
  where
    cc', hc' :: Tensor device dtype hxShape
    -- Broadcast over a new leading batch axis, then swap it behind the
    -- layer/direction axis so every batch entry sees the state of its own layer.
    cc' =
      transpose @0 @1
        . expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          False -- TODO: role of this flag is unclear; presumably ATen's @implicit@ argument (confirm)
        $ cc
    hc' =
      transpose @0 @1
        . expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          False -- TODO: role of this flag is unclear; presumably ATen's @implicit@ argument (confirm)
        $ hc
lstmForward dropoutOn (LSTMWithLearnedInit lstmModel@(LSTM _ (Dropout dropoutProb)) cc hc) input =
  lstm
    @shapeOrder
    @directionality
    @numLayers
    @seqLen
    @batchSize
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @tensorParameters
    @dtype
    @device
    (hmap' ToDependent . flattenParameters $ lstmModel)
    dropoutProb
    dropoutOn
    (cc', hc')
    input
  where
    cc', hc' :: Tensor device dtype hxShape
    -- Same layout as in the constant case, but the learned states are
    -- parameters and must first be turned into dependent tensors.
    cc' =
      transpose @0 @1
        . expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          False -- TODO: role of this flag is unclear; presumably ATen's @implicit@ argument (confirm)
        . toDependent
        $ cc
    hc' =
      transpose @0 @1
        . expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          False -- TODO: role of this flag is unclear; presumably ATen's @implicit@ argument (confirm)
        . toDependent
        $ hc

-- | Forward propagate the `LSTM` module and apply dropout on the outputs of each layer.
--
-- This is 'lstmForward' with the dropout switch fixed to 'True'.
--
-- >>> input :: CPUTensor 'D.Float '[5,16,10] <- randn
-- >>> spec = LSTMWithZerosInitSpec @10 @30 @3 @'Bidirectional @'D.Float @'( 'D.CPU, 0) (LSTMSpec (DropoutSpec 0.5))
-- >>> model <- A.sample spec
-- >>> :t lstmForwardWithDropout @'BatchFirst model input
-- lstmForwardWithDropout @'BatchFirst model input
--   :: (Tensor '( 'D.CPU, 0) 'D.Float '[5, 16, 60],
--       Tensor '( 'D.CPU, 0) 'D.Float '[6, 5, 30],
--       Tensor '( 'D.CPU, 0) 'D.Float '[6, 5, 30])
-- >>> (a,b,c) = lstmForwardWithDropout @'BatchFirst model input
-- >>> ((dtype a, shape a), (dtype b, shape b), (dtype c, shape c))
-- ((Float,[5,16,60]),(Float,[6,5,30]),(Float,[6,5,30]))
lstmForwardWithDropout ::
  forall
    shapeOrder
    batchSize
    seqLen
    directionality
    initialization
    numLayers
    inputSize
    outputSize
    hiddenSize
    inputShape
    outputShape
    hxShape
    parameters
    tensorParameters
    dtype
    device.
  ( KnownNat (NumberOfDirections directionality),
    KnownNat numLayers,
    KnownNat batchSize,
    KnownNat hiddenSize,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hxShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
    parameters ~ Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device),
    Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    tensorParameters ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor],
    HMap' ToDependent parameters tensorParameters
  ) =>
  LSTMWithInit
    inputSize
    hiddenSize
    numLayers
    directionality
    initialization
    dtype
    device ->
  Tensor device dtype inputShape ->
  ( Tensor device dtype outputShape,
    Tensor device dtype hxShape,
    Tensor device dtype hxShape
  )
lstmForwardWithDropout =
  lstmForward
    @shapeOrder
    @batchSize
    @seqLen
    @directionality
    @initialization
    @numLayers
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @parameters
    @tensorParameters
    @dtype
    @device
    True

-- | Forward propagate the `LSTM` module (without applying dropout on the outputs of each layer).
--
-- This is 'lstmForward' with the dropout switch fixed to 'False'.
--
-- >>> input :: CPUTensor 'D.Float '[5,16,10] <- randn
-- >>> spec = LSTMWithZerosInitSpec @10 @30 @3 @'Bidirectional @'D.Float @'( 'D.CPU, 0) (LSTMSpec (DropoutSpec 0.5))
-- >>> model <- A.sample spec
-- >>> :t lstmForwardWithoutDropout @'BatchFirst model input
-- lstmForwardWithoutDropout @'BatchFirst model input
--   :: (Tensor '( 'D.CPU, 0) 'D.Float '[5, 16, 60],
--       Tensor '( 'D.CPU, 0) 'D.Float '[6, 5, 30],
--       Tensor '( 'D.CPU, 0) 'D.Float '[6, 5, 30])
-- >>> (a,b,c) = lstmForwardWithoutDropout @'BatchFirst model input
-- >>> ((dtype a, shape a), (dtype b, shape b), (dtype c, shape c))
-- ((Float,[5,16,60]),(Float,[6,5,30]),(Float,[6,5,30]))
lstmForwardWithoutDropout ::
  forall
    shapeOrder
    batchSize
    seqLen
    directionality
    initialization
    numLayers
    inputSize
    outputSize
    hiddenSize
    inputShape
    outputShape
    hxShape
    parameters
    tensorParameters
    dtype
    device.
  ( KnownNat (NumberOfDirections directionality),
    KnownNat numLayers,
    KnownNat batchSize,
    KnownNat hiddenSize,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hxShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
    parameters ~ Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device),
    Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    tensorParameters ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor],
    HMap' ToDependent parameters tensorParameters
  ) =>
  LSTMWithInit
    inputSize
    hiddenSize
    numLayers
    directionality
    initialization
    dtype
    device ->
  Tensor device dtype inputShape ->
  ( Tensor device dtype outputShape,
    Tensor device dtype hxShape,
    Tensor device dtype hxShape
  )
lstmForwardWithoutDropout =
  lstmForward
    @shapeOrder
    @batchSize
    @seqLen
    @directionality
    @initialization
    @numLayers
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @parameters
    @tensorParameters
    @dtype
    @device
    False