{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Functional.NonLinearActivation where
import Control.Monad.Catch (MonadThrow)
import Data.Singletons (SingKind (..))
import GHC.TypeLits (Nat, Symbol, TypeError)
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked)
import Torch.GraduallyTyped.Shape (By (..), Dim (..), GetDimImplF, Name (..), SSelectDim (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.Internal.Cast (cast2)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import Type.Errors.Pretty (type (%), type (<>))
-- | Compile-time error text reported when softmax is requested along a
-- dimension that cannot be matched in the given shape.
--
-- The message embeds the offending selector @selector@ and the list of
-- dimensions @dimensions@ it was checked against.
type SoftMaxErrorMessage
  (selector :: By Symbol Nat)
  (dimensions :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot apply softmax on the dimension matching"
    % ""
    % " '" <> selector <> "'"
    % ""
    % "in the shape"
    % ""
    % " '" <> dimensions <> "'."
    % ""
-- | Validates the outcome of a dimension lookup for softmax.
--
-- * When @lookupResult@ is @'Nothing@, type checking fails with
--   'SoftMaxErrorMessage' built from the selector and the dimension list.
-- * When @lookupResult@ is @'Just@ of anything, the dimension list is
--   returned unchanged.
type family
  SoftmaxCheckF
    (selector :: By Symbol Nat)
    (dimensions :: [Dim (Name Symbol) (Size Nat)])
    (lookupResult :: Maybe (Dim (Name Symbol) (Size Nat))) ::
    [Dim (Name Symbol) (Size Nat)]
  where
  SoftmaxCheckF selector dimensions 'Nothing =
    TypeError (SoftMaxErrorMessage selector dimensions)
  SoftmaxCheckF _ dimensions ('Just _) =
    dimensions
-- | Computes the result shape of 'softmax' / 'logSoftmax' at the type level.
--
-- Equations are tried top to bottom:
--
-- 1. An unchecked dimension selector yields an unchecked shape.
-- 2. An unchecked input shape yields an unchecked shape.
-- 3. With both checked, the dimension list is passed through
--    'SoftmaxCheckF' together with @GetDimImplF by dims@, so the shape is
--    kept as is when that lookup produces @'Just@ and a type error is
--    raised when it produces @'Nothing@.
type family
  SoftmaxF
    (selectDim :: SelectDim (By Symbol Nat))
    (shape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  SoftmaxF 'UncheckedSelectDim _ =
    'UncheckedShape
  SoftmaxF _ 'UncheckedShape =
    'UncheckedShape
  SoftmaxF ('SelectDim by) ('Shape dims) =
    'Shape (SoftmaxCheckF by dims (GetDimImplF by dims))
-- | Shared signature of 'softmax' and 'logSoftmax'.
--
-- Both take a singleton describing the dimension to operate along and an
-- input tensor. The result tensor keeps the input's gradient, layout,
-- device and data type; its shape is @'SoftmaxF' selectDim shape@.
--
-- The computation runs in any 'MonadThrow' monad, so failures raised by
-- the underlying call are reported through @m@.
softmax,
  logSoftmax ::
    forall selectDim gradient layout device dataType shape shape' m.
    ( MonadThrow m,
      shape' ~ SoftmaxF selectDim shape,
      Catch shape'
    ) =>
    SSelectDim selectDim ->
    Tensor gradient layout device dataType shape ->
    m (Tensor gradient layout device dataType shape')
-- Softmax along the selected dimension of the input tensor.
--
-- The type of this function is given by the shared @softmax, logSoftmax ::@
-- signature above; no second signature is declared here, since a binding
-- may only have one.
--
-- The singleton selector is demoted to a runtime value and stripped of its
-- checked/unchecked wrapper, then dispatched on how the dimension is
-- identified:
--
-- * by name:  calls the ATen binding @softmax_tn@ with the dimension name;
-- * by index: calls the ATen binding @softmax_tl@ with the dimension index.
--
-- Either call is lifted from 'IO' into the caller's 'MonadThrow' monad via
-- 'unsafeThrowableIO'.
softmax selectDim tensor =
  case forgetIsChecked (fromSing selectDim) of
    ByName name ->
      unsafeThrowableIO $ cast2 ATen.softmax_tn tensor name
    ByIndex index ->
      -- The demoted index is an 'Integer'; it is narrowed to 'Int' before
      -- being handed to 'cast2'.
      unsafeThrowableIO $ cast2 ATen.softmax_tl tensor (fromInteger index :: Int)
-- Log-softmax along the selected dimension of the input tensor.
--
-- The type of this function is given by the shared @softmax, logSoftmax ::@
-- signature above; no second signature is declared here, since a binding
-- may only have one.
--
-- The singleton selector is demoted to a runtime value and stripped of its
-- checked/unchecked wrapper, then dispatched on how the dimension is
-- identified:
--
-- * by name:  calls the ATen binding @log_softmax_tn@ with the dimension name;
-- * by index: calls the ATen binding @log_softmax_tl@ with the dimension index.
--
-- Either call is lifted from 'IO' into the caller's 'MonadThrow' monad via
-- 'unsafeThrowableIO'.
logSoftmax selectDim tensor =
  case forgetIsChecked (fromSing selectDim) of
    ByName name ->
      unsafeThrowableIO $ cast2 ATen.log_softmax_tn tensor name
    ByIndex index ->
      -- The demoted index is an 'Integer'; it is narrowed to 'Int' before
      -- being handed to 'cast2'.
      unsafeThrowableIO $ cast2 ATen.log_softmax_tl tensor (fromInteger index :: Int)