{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}

module Torch.GraduallyTyped.NN.Activation where

import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.NN.Functional.Activation (gelu, geluNew, relu)
import Torch.GraduallyTyped.NN.Functional.NonLinearActivation (SoftmaxF, logSoftmax, softmax)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Shape (By, SSelectDim, SelectDim)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (tanh)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Prelude hiding (tanh)

-- | 'Softmax' is a non-linear activation function.
--
-- The layer holds no tensors. Its only field, 'softmaxSelectDim', is the
-- 'SSelectDim' value naming the dimension along which 'softmax' is applied
-- by the 'HasForward' instance.
data Softmax (selectDim :: SelectDim (By Symbol Nat)) where
  Softmax ::
    forall selectDim.
    {softmaxSelectDim :: SSelectDim selectDim} ->
    Softmax selectDim
  deriving stock (Generic)

-- | A 'Softmax' layer is its own model specification.
type instance ModelSpec (Softmax selectDim) = Softmax selectDim

-- | Initialising a 'Softmax' layer is trivial. The spec already is the
-- layer, so it is returned as-is and the generator is passed through
-- unchanged.
instance
  HasInitialize
    (Softmax selectDim)
    generatorDevice
    (Softmax selectDim)
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | 'Softmax' stores nothing in a state dict.
--
-- Loading ignores the key and returns the spec itself. Saving ignores both
-- the key and the layer and writes nothing.
instance HasStateDict (Softmax selectDim) where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()

-- | Applies 'softmax' along the dimension stored in the layer.
--
-- The output shape is computed by 'SoftmaxF'. The @Catch shape'@ constraint
-- is demanded by 'softmax' itself. 'softmax' runs in the caller's
-- @MonadThrow@ monad, so failures surface there. The generator is not
-- consumed and is returned unchanged.
instance
  ( shape' ~ SoftmaxF selectDim shape,
    Catch shape',
    output ~ Tensor requiresGradient layout device dataType shape'
  ) =>
  HasForward
    (Softmax selectDim)
    (Tensor requiresGradient layout device dataType shape)
    generator
    output
    generator
  where
  forward (Softmax dim) input g = (,g) <$> softmax dim input

-- | 'LogSoftmax' is a non-linear activation function.
--
-- The layer holds no tensors. Its only field, 'logSoftmaxSelectDim', is the
-- 'SSelectDim' value naming the dimension along which 'logSoftmax' is
-- applied by the 'HasForward' instance.
data LogSoftmax (selectDim :: SelectDim (By Symbol Nat)) where
  LogSoftmax ::
    forall selectDim.
    {logSoftmaxSelectDim :: SSelectDim selectDim} ->
    LogSoftmax selectDim
  deriving stock (Generic)

-- | A 'LogSoftmax' layer is its own model specification.
type instance ModelSpec (LogSoftmax selectDim) = LogSoftmax selectDim

-- | Initialising a 'LogSoftmax' layer is trivial. The spec already is the
-- layer, so it is returned as-is and the generator is passed through
-- unchanged.
instance
  HasInitialize
    (LogSoftmax selectDim)
    generatorDevice
    (LogSoftmax selectDim)
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | 'LogSoftmax' stores nothing in a state dict.
--
-- Loading ignores the key and returns the spec itself. Saving ignores both
-- the key and the layer and writes nothing.
instance HasStateDict (LogSoftmax selectDim) where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()

-- | Applies 'logSoftmax' along the dimension stored in the layer.
--
-- The output shape is computed by 'SoftmaxF', the same shape function used
-- by 'logSoftmax'. The @Catch shape'@ constraint is demanded by
-- 'logSoftmax' itself. 'logSoftmax' runs in the caller's @MonadThrow@
-- monad, so failures surface there. The generator is not consumed and is
-- returned unchanged.
instance
  ( shape' ~ SoftmaxF selectDim shape,
    Catch shape',
    output ~ Tensor requiresGradient layout device dataType shape'
  ) =>
  HasForward
    (LogSoftmax selectDim)
    (Tensor requiresGradient layout device dataType shape)
    generator
    output
    generator
  where
  forward (LogSoftmax dim) input g = (,g) <$> logSoftmax dim input

-- | 'Relu' is a piecewise-linear activation function.
--
-- The layer has a single nullary constructor and holds no tensors.
data Relu where
  Relu :: Relu
  deriving stock (Eq, Ord, Show, Generic)

-- | A 'Relu' layer is its own model specification.
type instance ModelSpec Relu = Relu

-- | Initialising a 'Relu' layer is trivial. The spec already is the layer,
-- so it is returned as-is and the generator is passed through unchanged.
instance
  HasInitialize
    Relu
    generatorDevice
    Relu
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | 'Relu' stores nothing in a state dict.
--
-- Loading ignores the key and returns the spec itself. Saving ignores both
-- the key and the layer and writes nothing.
instance HasStateDict Relu where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()

-- | Applies 'relu' to the input tensor.
--
-- By the type of 'relu', the gradient setting, layout, device, data type and
-- shape are all preserved. 'relu' is a pure function, so this step cannot
-- fail. The generator is returned unchanged.
instance
  HasForward
    Relu
    (Tensor requiresGradient layout device dataType shape)
    generator
    (Tensor requiresGradient layout device dataType shape)
    generator
  where
  forward Relu input g = pure (relu input, g)

-- | 'Gelu' is a non-linear activation function.
--
-- The layer has a single nullary constructor and holds no tensors.
data Gelu where
  Gelu :: Gelu
  deriving stock (Eq, Ord, Show, Generic)

-- | A 'Gelu' layer is its own model specification.
type instance ModelSpec Gelu = Gelu

-- | Initialising a 'Gelu' layer is trivial. The spec already is the layer,
-- so it is returned as-is and the generator is passed through unchanged.
instance
  HasInitialize
    Gelu
    generatorDevice
    Gelu
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | 'Gelu' stores nothing in a state dict.
--
-- Loading ignores the key and returns the spec itself. Saving ignores both
-- the key and the layer and writes nothing.
instance HasStateDict Gelu where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()

-- | Applies 'gelu' to the input tensor.
--
-- By the type of 'gelu', the gradient setting, layout, device, data type and
-- shape are all preserved. 'gelu' is a pure function, so this step cannot
-- fail. The generator is returned unchanged.
instance
  HasForward
    Gelu
    (Tensor requiresGradient layout device dataType shape)
    generator
    (Tensor requiresGradient layout device dataType shape)
    generator
  where
  forward Gelu input g = pure (gelu input, g)

-- | 'GeluNew' is a non-linear activation function.
-- It is a modified version of the 'Gelu' function, computed by 'geluNew'.
--
-- The layer has a single nullary constructor and holds no tensors.
data GeluNew where
  GeluNew :: GeluNew
  deriving stock (Eq, Ord, Show, Generic)

-- | A 'GeluNew' layer is its own model specification.
type instance ModelSpec GeluNew = GeluNew

-- | Initialising a 'GeluNew' layer is trivial. The spec already is the
-- layer, so it is returned as-is and the generator is passed through
-- unchanged.
instance
  HasInitialize
    GeluNew
    generatorDevice
    GeluNew
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | 'GeluNew' stores nothing in a state dict.
--
-- Loading ignores the key and returns the spec itself. Saving ignores both
-- the key and the layer and writes nothing.
instance HasStateDict GeluNew where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()

-- | Applies 'geluNew' to the input tensor.
--
-- Unlike 'gelu', 'geluNew' is monadic (it runs in the caller's
-- @MonadThrow@ monad), so failures surface there. By its type, the gradient
-- setting, layout, device, data type and shape are preserved. The generator
-- is returned unchanged.
instance
  HasForward
    GeluNew
    (Tensor requiresGradient layout device dataType shape)
    generator
    (Tensor requiresGradient layout device dataType shape)
    generator
  where
  forward GeluNew input g = (,g) <$> geluNew input

-- | 'Tanh' is a non-linear activation function.
--
-- The layer has a single nullary constructor and holds no tensors.
data Tanh where
  Tanh :: Tanh
  deriving stock (Eq, Ord, Show, Generic)

-- | A 'Tanh' layer is its own model specification.
type instance ModelSpec Tanh = Tanh

-- | Initialising a 'Tanh' layer is trivial. The spec already is the layer,
-- so it is returned as-is and the generator is passed through unchanged.
instance
  HasInitialize
    Tanh
    generatorDevice
    Tanh
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | 'Tanh' stores nothing in a state dict.
--
-- Loading ignores the key and returns the spec itself. Saving ignores both
-- the key and the layer and writes nothing.
instance HasStateDict Tanh where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()

-- | Applies the tensor-level 'tanh' (not the Prelude one, which is hidden in
-- this module) to the input.
--
-- By the type of 'tanh', the gradient setting, layout, device, data type and
-- shape are all preserved. 'tanh' is a pure function, so this step cannot
-- fail. The generator is returned unchanged.
instance
  HasForward
    Tanh
    (Tensor requiresGradient layout device dataType shape)
    generator
    (Tensor requiresGradient layout device dataType shape)
    generator
  where
  forward Tanh input g = pure (tanh input, g)