{-# LANGUAGE RankNTypes #-}

module Torch.GraduallyTyped.NN.Functional.Activation where

import Control.Monad.Catch (MonadThrow)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (addScalar, mulScalar, powScalar, tanh)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.Internal.Cast (cast1, cast2, cast3)
import qualified Torch.Internal.Managed.Native as ATen
import Torch.Scalar (Scalar)
import Prelude (Float, pure, ($), (*), (+), (.), (/))
import qualified Prelude (pi, sqrt)
import Torch.Internal.GC (unsafeThrowableIO)

-- $setup
-- >>> import Torch.GraduallyTyped.Prelude.List (SList (..))
-- >>> import Torch.GraduallyTyped

-- | Thresholds each element of the input tensor:
-- elements that do not exceed the threshold are replaced with the given value.
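--
-- A usage sketch, reusing the doctest setup from 'geluNew' below
-- (0.5 does not exceed the threshold of 1 and is therefore replaced by the value 0):
--
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) 0.5
-- >>> t' <- threshold (1 :: Float) (0 :: Float) t
-- >>> fromTensor @Float t'
-- 0.0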
threshold ::
  forall threshold value gradient layout device dataType shape m.
  (Scalar threshold, Scalar value, MonadThrow m) =>
  -- | threshold
  threshold ->
  -- | value
  value ->
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  m (Tensor gradient layout device dataType shape)
threshold thresholdValue value tensor =
  unsafeThrowableIO $ cast3 ATen.threshold_tss tensor thresholdValue value

-- | Applies the rectified linear unit function element-wise, that is,
-- \[
-- \text{ReLU}(x) = \max(0, x).
-- \]
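--
-- A usage sketch (negative inputs are mapped to zero):
--
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) (-0.5)
-- >>> fromTensor @Float (relu t)
-- 0.0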
relu ::
  forall gradient layout device dataType shape.
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  Tensor gradient layout device dataType shape
relu = unsafePerformIO . cast1 ATen.relu_t

-- | Applies the gaussian error linear unit function element-wise, that is,
-- \[
-- \text{GELU}(x) = x * \Phi(x),
-- \]
-- where \(\Phi(x)\) is the cumulative distribution function of the standard normal distribution.
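--
-- A usage sketch (\(\text{GELU}(0) = 0\), since \(\Phi(0)\) is multiplied by \(x = 0\)):
--
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) 0.0
-- >>> fromTensor @Float (gelu t)
-- 0.0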
gelu ::
  forall gradient layout device dataType shape.
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  Tensor gradient layout device dataType shape
gelu = unsafePerformIO . cast1 ATen.gelu_t

-- | Applies the tanh approximation of the gaussian error linear unit function element-wise, that is,
-- \[
-- \text{GELU}_{\text{new}}(x) = 0.5 * x * (1 + \tanh(\sqrt{2/\pi} * (x + 0.044715 * x^3))).
-- \]
--
-- This is the implementation of the GELU activation function from
-- Google's BERT repo (and coincidentally also from OpenAI's GPT).
-- See also https://arxiv.org/abs/1606.08415.
--
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) 0.5
-- >>> t' <- geluNew t
-- >>> fromTensor @Float t'
-- 0.345714
geluNew ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  m (Tensor gradient layout device dataType shape)
geluNew x = do
  -- 0.5 * x
  xHalfed <- x `mulScalar` (0.5 :: Float)
  -- 0.044715 * x^3
  xCubed <- x `powScalar` (3.0 :: Float)
  xCubedScaled <- xCubed `mulScalar` (0.044715 :: Float)
  -- sqrt (2 / pi) * (x + 0.044715 * x^3)
  x' <- (x + xCubedScaled) `mulScalar` Prelude.sqrt ((2 :: Float) / Prelude.pi)
  -- 1 + tanh (sqrt (2 / pi) * (x + 0.044715 * x^3))
  x'' <- tanh x' `addScalar` (1 :: Float)
  -- 0.5 * x * (1 + tanh (sqrt (2 / pi) * (x + 0.044715 * x^3)))
  pure $ xHalfed * x''

-- | Applies the HardTanh function element-wise, that is,
-- clamps each element of the input to the closed interval \([\text{minValue}, \text{maxValue}]\).
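--
-- A usage sketch (2.0 is clamped to the maximum value 1):
--
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) 2.0
-- >>> t' <- hardtanh (-1 :: Float) (1 :: Float) t
-- >>> fromTensor @Float t'
-- 1.0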
hardtanh ::
  forall minValue maxValue gradient layout device dataType shape m.
  (Scalar minValue, Scalar maxValue, MonadThrow m) =>
  -- | minimum value
  minValue ->
  -- | maximum value
  maxValue ->
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  m (Tensor gradient layout device dataType shape)
hardtanh minValue maxValue tensor =
  unsafeThrowableIO $ cast3 ATen.hardtanh_tss tensor minValue maxValue

-- | Applies the hardswish function element-wise, that is,
-- \[
-- \text{Hardswish}(x) = x * \frac{\text{ReLU6}(x + 3)}{6}.
-- \]
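--
-- A usage sketch (for \(x \geq 3\), \(\text{ReLU6}(x + 3) = 6\), so the function is the identity):
--
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) 3.0
-- >>> fromTensor @Float (hardswish t)
-- 3.0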
hardswish ::
  forall gradient layout device dataType shape.
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  Tensor gradient layout device dataType shape
hardswish = unsafePerformIO . cast1 ATen.hardswish_t

-- | Applies the exponential linear unit function element-wise, with the given \(\alpha\) value, that is,
-- \[
-- \text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1)).
-- \]
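--
-- A usage sketch (positive inputs pass through unchanged):
--
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) 0.5
-- >>> t' <- elu (1 :: Float) t
-- >>> fromTensor @Float t'
-- 0.5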
elu ::
  forall alpha gradient layout device dataType shape m.
  (Scalar alpha, MonadThrow m) =>
  -- | alpha value for ELU formulation
  alpha ->
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  m (Tensor gradient layout device dataType shape)
elu alpha tensor = unsafeThrowableIO $ cast2 ATen.elu_ts tensor alpha

-- | Applies the scaled exponential linear unit function element-wise, that is,
-- \[
-- \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))),
-- \]
-- with \(\alpha = 1.6732632423543772848170429916717\)
-- and \(\text{scale}=1.0507009873554804934193349852946\).
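--
-- A usage sketch (the function is zero at the origin):
--
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) 0.0
-- >>> t' <- selu t
-- >>> fromTensor @Float t'
-- 0.0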
selu ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  m (Tensor gradient layout device dataType shape)
selu = unsafeThrowableIO . cast1 ATen.selu_t

-- | Applies the continuously differentiable exponential linear unit function element-wise, that is,
-- \[
-- \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1)).
-- \]
celu ::
  forall alpha gradient layout device dataType shape m.
  (Scalar alpha, MonadThrow m) =>
  -- | alpha
  alpha ->
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  m (Tensor gradient layout device dataType shape)
celu alpha tensor = unsafeThrowableIO $ cast2 ATen.celu_ts tensor alpha

-- | Applies the leaky rectified linear unit function element-wise, that is,
-- \[
-- \text{LeakyReLU}(x) = \max(0,x) + \text{negativeSlope} * \min(0,x),
-- \]
-- where the angle of the negative slope can be controlled.
-- A typical value for the negative slope is 0.01.
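--
-- A usage sketch (a deliberately large slope of 0.5 keeps the expected output exact):
--
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) (-1.0)
-- >>> t' <- leakyRelu (0.5 :: Float) t
-- >>> fromTensor @Float t'
-- -0.5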
leakyRelu ::
  forall negativeSlope gradient layout device dataType shape m.
  (Scalar negativeSlope, MonadThrow m) =>
  -- | negative slope
  negativeSlope ->
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  m (Tensor gradient layout device dataType shape)
leakyRelu negativeSlope tensor =
  unsafeThrowableIO $ cast2 ATen.leaky_relu_ts tensor negativeSlope

-- | Applies the parameterized rectified linear unit function element-wise, that is,
-- \[
-- \text{PReLU}(x) = \max(0, x) + \text{weight} * \min(0, x).
-- \]
-- The weight parameter is typically learnable.
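--
-- A usage sketch; note that this assumes ATen accepts a single-element weight
-- tensor here, as @torch.prelu@ does in Python:
--
-- >>> w <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) 0.5
-- >>> t <- sFull (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNil)) (-1.0)
-- >>> t' <- prelu w t
-- >>> fromTensor @Float t'
-- -0.5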
prelu ::
  forall gradient' gradient layout device dataType shape m.
  MonadThrow m =>
  -- | weight (typically learnable)
  Tensor gradient' layout device dataType shape ->
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  m (Tensor gradient layout device dataType shape)
prelu weight tensor = unsafeThrowableIO $ cast2 ATen.prelu_tt tensor weight