{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Functional.NonLinearActivation where

import Control.Monad.Catch (MonadThrow)
import Data.Singletons (SingKind (..))
import GHC.TypeLits (Nat, Symbol, TypeError)
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked)
import Torch.GraduallyTyped.Shape (By (..), Dim (..), GetDimImplF, Name (..), SSelectDim (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.Internal.Cast (cast2)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import Type.Errors.Pretty (type (%), type (<>))

-- $setup
-- >>> import Torch.GraduallyTyped.Prelude.List (SList (..))
-- >>> import Torch.GraduallyTyped

-- | Pretty-printed type error used by 'SoftmaxCheckF' when the dimension
-- selector does not match any dimension of the given shape.
type SoftMaxErrorMessage
  (selector :: By Symbol Nat)
  (shapeDims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot apply softmax on the dimension matching"
    % ""
    % "    '" <> selector <> "'"
    % ""
    % "in the shape"
    % ""
    % "    '" <> shapeDims <> "'."
    % ""

-- | Validates the result of a dimension lookup for softmax.
--
-- If the lookup found a dimension ('Just'), the dimension list is returned
-- unchanged. If it found nothing ('Nothing'), compilation fails with
-- 'SoftMaxErrorMessage'.
type family SoftmaxCheckF
  (selector :: By Symbol Nat)
  (shapeDims :: [Dim (Name Symbol) (Size Nat)])
  (lookupResult :: Maybe (Dim (Name Symbol) (Size Nat))) ::
  [Dim (Name Symbol) (Size Nat)]
  where
    SoftmaxCheckF _ shapeDims ('Just _) = shapeDims
    SoftmaxCheckF selector shapeDims 'Nothing = TypeError (SoftMaxErrorMessage selector shapeDims)

-- | Output shape of 'softmax' and 'logSoftmax'.
--
-- If either the dimension selector or the shape is unchecked, the result is
-- 'UncheckedShape'. Otherwise the selected dimension is looked up with
-- 'GetDimImplF' and validated by 'SoftmaxCheckF'. A successful check yields
-- the input dimensions unchanged.
type family SoftmaxF
  (selectDim :: SelectDim (By Symbol Nat))
  (shape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
  Shape [Dim (Name Symbol) (Size Nat)]
  where
    SoftmaxF 'UncheckedSelectDim _ = 'UncheckedShape
    SoftmaxF _ 'UncheckedShape = 'UncheckedShape
    SoftmaxF ('SelectDim selector) ('Shape shapeDims) =
      'Shape (SoftmaxCheckF selector shapeDims (GetDimImplF selector shapeDims))

-- | Applies the softmax function that is defined as:
--
-- \[
-- \mathrm{Softmax}(\mathrm{input}_{i}) = \frac{\exp\left(\mathrm{input}_{i}\right)}{\sum_j \exp\left(\mathrm{input}_{j}\right)}
-- \]
--
-- Softmax is applied to all slices along the dimension chosen by @selectDim@,
-- and will re-scale them so that the elements lie in the range \([0, 1]\) and sum to \(1\).
--
-- 'logSoftmax' shares this signature. It applies ATen's @log_softmax@ along
-- the same dimension instead of @softmax@.
--
-- The dimension can be selected either by name ('ByName') or by index
-- ('ByIndex'). When both the selector and the input shape are checked,
-- 'SoftmaxF' verifies at compile time that the selected dimension exists,
-- and the output shape equals the input shape. If either one is unchecked,
-- the output shape is 'UncheckedShape'.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> (input, _) <- sRandn (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)) g
-- >>> result <- softmax (SSelectDim (SByName @"feature")) input
-- >>> :type result
-- result
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
softmax,
  logSoftmax ::
    forall selectDim gradient layout device dataType shape shape' m.
    (MonadThrow m, shape' ~ SoftmaxF selectDim shape, Catch shape') =>
    SSelectDim selectDim ->
    Tensor gradient layout device dataType shape ->
    m (Tensor gradient layout device dataType shape')
softmax selectDim tensor =
  -- Demote the singleton selector to a term-level value. 'forgetIsChecked'
  -- strips the checked/unchecked wrapper, because the runtime call only
  -- needs the name or index itself.
  case forgetIsChecked (fromSing selectDim) of
    ByName name ->
      unsafeThrowableIO $ cast2 ATen.softmax_tn tensor name
    ByIndex index ->
      -- 'cast2' marshals the 'Int' into the integer argument of the ATen call.
      unsafeThrowableIO $ cast2 ATen.softmax_tl tensor (fromInteger index :: Int)
logSoftmax selectDim tensor =
  -- Same dispatch as 'softmax', but routed to ATen's log_softmax kernels.
  case forgetIsChecked (fromSing selectDim) of
    ByName name ->
      unsafeThrowableIO $ cast2 ATen.log_softmax_tn tensor name
    ByIndex index ->
      unsafeThrowableIO $ cast2 ATen.log_softmax_tl tensor (fromInteger index :: Int)