{-# LANGUAGE RankNTypes #-}
module Torch.GraduallyTyped.NN.Functional.Activation where
import Control.Monad.Catch (MonadThrow)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (addScalar, mulScalar, powScalar, tanh)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.Internal.Cast (cast1, cast2, cast3)
import qualified Torch.Internal.Managed.Native as ATen
import Torch.Scalar (Scalar)
import Prelude (Float, pure, ($), (*), (+), (.), (/))
import qualified Prelude (pi, sqrt)
import Torch.Internal.GC (unsafeThrowableIO)
-- | Thresholds the input tensor by delegating to @ATen.threshold_tss@.
--
-- The native function is called with the tensor first, followed by the
-- threshold and the replacement value; this wrapper takes the two scalars
-- first and the tensor last. The resulting IO action is lifted into @m@ with
-- 'unsafeThrowableIO'. Gradient, layout, device, data type, and shape of the
-- result are those of the input.
--
-- NOTE(review): presumably elements that do not exceed the threshold are
-- replaced by the given value, as in PyTorch's @threshold@; confirm against
-- the ATen documentation.
threshold ::
  forall threshold value gradient layout device dataType shape m.
  (Scalar threshold, Scalar value, MonadThrow m) =>
  -- | threshold
  threshold ->
  -- | replacement value
  value ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  m (Tensor gradient layout device dataType shape)
threshold thresholdValue value tensor =
  unsafeThrowableIO $ cast3 ATen.threshold_tss tensor thresholdValue value
-- | Applies @ATen.relu_t@ to the input tensor.
--
-- The IO action produced by 'cast1' is run with 'unsafePerformIO' so that the
-- function has a pure type. Gradient, layout, device, data type, and shape of
-- the result are those of the input.
--
-- NOTE(review): running the native call through 'unsafePerformIO' presumes it
-- is side-effect free and does not fail for any input; confirm.
relu ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  Tensor gradient layout device dataType shape
relu = unsafePerformIO . cast1 ATen.relu_t
-- | Applies @ATen.gelu_t@ to the input tensor.
--
-- The IO action produced by 'cast1' is run with 'unsafePerformIO' so that the
-- function has a pure type. Gradient, layout, device, data type, and shape of
-- the result are those of the input.
--
-- NOTE(review): running the native call through 'unsafePerformIO' presumes it
-- is side-effect free and does not fail for any input; confirm.
gelu ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  Tensor gradient layout device dataType shape
gelu = unsafePerformIO . cast1 ATen.gelu_t
-- | Tanh-based GELU variant assembled from pointwise tensor operations.
--
-- The steps below compute
--
-- > 0.5 * x * (1 + tanh (sqrt (2 / pi) * (x + 0.044715 * x ^ 3)))
--
-- Every scalar operation ('mulScalar', 'powScalar', 'addScalar') returns its
-- result in @m@, so the computation is sequenced in a @do@ block. Gradient,
-- layout, device, data type, and shape of the result are those of the input.
geluNew ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  m (Tensor gradient layout device dataType shape)
geluNew x = do
  -- 0.5 * x, the outer factor of the formula.
  xHalfed <- x `mulScalar` (0.5 :: Float)
  -- 0.044715 * x ^ 3, the cubic correction term.
  xCubed <- x `powScalar` (3.0 :: Float)
  xCubedScaled <- xCubed `mulScalar` (0.044715 :: Float)
  -- sqrt (2 / pi) * (x + 0.044715 * x ^ 3), the argument of tanh.
  x' <- (x + xCubedScaled) `mulScalar` Prelude.sqrt ((2 :: Float) / Prelude.pi)
  -- Function application binds tighter than the infix call, so this is
  -- (tanh x') + 1.
  x'' <- tanh x' `addScalar` (1 :: Float)
  pure $ xHalfed * x''
-- | Applies @ATen.hardtanh_tss@ to the input tensor.
--
-- The native function is called with the tensor first, followed by the
-- minimum and maximum values; this wrapper takes the two scalars first and
-- the tensor last. The resulting IO action is lifted into @m@ with
-- 'unsafeThrowableIO'. Gradient, layout, device, data type, and shape of the
-- result are those of the input.
--
-- NOTE(review): presumably this clamps each element to the closed interval
-- between the two scalars, as in PyTorch's @hardtanh@; confirm against the
-- ATen documentation.
hardtanh ::
  forall minValue maxValue gradient layout device dataType shape m.
  (Scalar minValue, Scalar maxValue, MonadThrow m) =>
  -- | minimum value
  minValue ->
  -- | maximum value
  maxValue ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  m (Tensor gradient layout device dataType shape)
hardtanh minValue maxValue tensor =
  unsafeThrowableIO $ cast3 ATen.hardtanh_tss tensor minValue maxValue
-- | Applies @ATen.hardswish_t@ to the input tensor.
--
-- The IO action produced by 'cast1' is run with 'unsafePerformIO' so that the
-- function has a pure type. Gradient, layout, device, data type, and shape of
-- the result are those of the input.
--
-- NOTE(review): running the native call through 'unsafePerformIO' presumes it
-- is side-effect free and does not fail for any input; confirm.
hardswish ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  Tensor gradient layout device dataType shape
hardswish = unsafePerformIO . cast1 ATen.hardswish_t
-- | Applies @ATen.elu_ts@ to the input tensor with the given scalar.
--
-- The native function is called with the tensor first and the scalar second;
-- this wrapper takes the scalar first. The resulting IO action is lifted into
-- @m@ with 'unsafeThrowableIO'. Gradient, layout, device, data type, and
-- shape of the result are those of the input.
--
-- NOTE(review): presumably @alpha@ scales the negative branch of ELU, as in
-- PyTorch's @elu@; confirm against the ATen documentation.
elu ::
  forall alpha gradient layout device dataType shape m.
  (Scalar alpha, MonadThrow m) =>
  -- | alpha
  alpha ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  m (Tensor gradient layout device dataType shape)
elu alpha tensor = unsafeThrowableIO $ cast2 ATen.elu_ts tensor alpha
-- | Applies @ATen.selu_t@ to the input tensor.
--
-- The IO action produced by 'cast1' is lifted into @m@ with
-- 'unsafeThrowableIO'. Gradient, layout, device, data type, and shape of the
-- result are those of the input.
selu ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  m (Tensor gradient layout device dataType shape)
selu = unsafeThrowableIO . cast1 ATen.selu_t
-- | Applies @ATen.celu_ts@ to the input tensor with the given scalar.
--
-- The native function is called with the tensor first and the scalar second;
-- this wrapper takes the scalar first. The resulting IO action is lifted into
-- @m@ with 'unsafeThrowableIO'. Gradient, layout, device, data type, and
-- shape of the result are those of the input.
--
-- NOTE(review): presumably @alpha@ is the CELU scale parameter, as in
-- PyTorch's @celu@; confirm against the ATen documentation.
celu ::
  forall alpha gradient layout device dataType shape m.
  (Scalar alpha, MonadThrow m) =>
  -- | alpha
  alpha ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  m (Tensor gradient layout device dataType shape)
celu alpha tensor = unsafeThrowableIO $ cast2 ATen.celu_ts tensor alpha
-- | Applies @ATen.leaky_relu_ts@ to the input tensor with the given scalar.
--
-- The native function is called with the tensor first and the scalar second;
-- this wrapper takes the scalar first. The resulting IO action is lifted into
-- @m@ with 'unsafeThrowableIO'. Gradient, layout, device, data type, and
-- shape of the result are those of the input.
--
-- NOTE(review): presumably the scalar is the slope applied to negative
-- inputs, as in PyTorch's @leaky_relu@; confirm against the ATen
-- documentation.
leakyRelu ::
  forall negativeSlope gradient layout device dataType shape m.
  (Scalar negativeSlope, MonadThrow m) =>
  -- | negative slope
  negativeSlope ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  m (Tensor gradient layout device dataType shape)
leakyRelu negativeSlope tensor =
  unsafeThrowableIO $ cast2 ATen.leaky_relu_ts tensor negativeSlope
-- | Applies @ATen.prelu_tt@ to the input tensor with the given weight tensor.
--
-- The native function is called with the input tensor first and the weight
-- second; this wrapper takes the weight first. The weight may have its own
-- gradient parameter (@gradient'@) but shares layout, device, data type, and
-- shape with the input. The result carries the input tensor's gradient
-- parameter. The resulting IO action is lifted into @m@ with
-- 'unsafeThrowableIO'.
--
-- NOTE(review): presumably the weight holds the learnable slope for negative
-- inputs, as in PyTorch's @prelu@; confirm against the ATen documentation.
prelu ::
  forall gradient' gradient layout device dataType shape m.
  MonadThrow m =>
  -- | weight
  Tensor gradient' layout device dataType shape ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  m (Tensor gradient layout device dataType shape)
prelu weight tensor = unsafeThrowableIO $ cast2 ATen.prelu_tt tensor weight