{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}

module Torch.GraduallyTyped.Tensor.MathOperations.Reduction where

import Control.Monad.Catch (MonadThrow (throwM))
import Control.Monad.State (execState, modify)
import Data.Bifunctor (Bifunctor (first), second)
import Data.Foldable (for_)
import Data.Kind (Constraint)
import qualified Data.Set as Set
import Data.Singletons (SingI (..), SingKind (..))
import Foreign.ForeignPtr (ForeignPtr)
import GHC.TypeLits (ErrorMessage, Nat, Symbol, TypeError)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (ReplaceDimSizeImplF)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SSelectDim, SSelectDims, SelectDim (..), SelectDims (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import qualified Torch.Internal.Cast as ATen (cast1, cast3)
import qualified Torch.Internal.Class as ATen (Castable (cast), uncast)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Type as ATen (Tensor)
import Type.Errors.Pretty (type (%), type (<>))
import Prelude hiding (all, any)

-- $setup
-- >>> import Torch.GraduallyTyped.Prelude.List (SList (..))
-- >>> import Torch.GraduallyTyped
-- >>> import Prelude hiding (all, any)

-- | Type-level error message reported by 'ReductionCheckF' when the
-- reduction named @reduction@ (e.g. @"all"@, @"any"@, @"argmax"@, @"mean"@)
-- cannot be applied to the dimension selected by @by@ in the shape @dims@
-- (presumably because no dimension matches the selection).
type ReductionErrorMessage :: Symbol -> By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage

type ReductionErrorMessage reduction by dims =
  "Cannot apply '" <> reduction <> "' on the dimension matching"
    % ""
    % "    '" <> by <> "'"
    % ""
    % "in the shape"
    % ""
    % "    '" <> dims <> "'."
    % ""

-- | Unwraps the dimension list produced by a reduction's shape update.
--
-- A 'Nothing result means the selected dimension could not be updated; it is
-- turned into a custom type error built from 'ReductionErrorMessage'.
-- @'Just dims'@ yields @dims@ unchanged.
type ReductionCheckF ::
  Symbol ->
  By Symbol Nat ->
  [Dim (Name Symbol) (Size Nat)] ->
  Maybe [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)]
type family ReductionCheckF reduction by dims result where
  ReductionCheckF op selector shape 'Nothing =
    TypeError (ReductionErrorMessage op selector shape)
  ReductionCheckF _ _ _ ('Just reduced) = reduced

-- | Result shape of a reduction over one selected dimension that keeps that
-- dimension with size 1. In this module it backs 'AllDimF', 'AnyDimF' and
-- 'ArgmaxF'; @reduction@ is only used to label type errors.
--
-- Equations are tried top to bottom:
--
-- * an unchecked selection yields an unchecked result shape;
-- * an unchecked input shape yields an unchecked result shape;
-- * otherwise the selected dimension's size is replaced with @'Size 1@, and
--   'ReductionCheckF' raises a type error if 'ReplaceDimSizeImplF' returns
--   'Nothing.
type BoolReductionF ::
  Symbol ->
  SelectDim (By Symbol Nat) ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family BoolReductionF reduction selectDim shape where
  BoolReductionF _ 'UncheckedSelectDim _ = 'UncheckedShape
  BoolReductionF _ _ 'UncheckedShape = 'UncheckedShape
  BoolReductionF reduction ('SelectDim by) ('Shape dims) = 'Shape (ReductionCheckF reduction by dims (ReplaceDimSizeImplF by dims ('Size 1)))

-- | Tests whether all elements of the input evaluate to True.
-- The result is a rank-0 (scalar) boolean tensor.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> shape = SShape $ SName @"*" :&: SSize @2 :|: SName @"*" :&: SSize @4 :|: SNil
-- >>> (t, _) <- sRandn (TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) shape) g
-- >>> t' <- all =<< bool t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Bool)
--        ('Shape '[])
all ::
  forall requiresGradient layout device dataType shape m.
  MonadThrow m =>
  Tensor requiresGradient layout device dataType shape ->
  m (Tensor requiresGradient layout device ('DataType 'Bool) ('Shape '[]))
-- 'unsafeThrowableIO' runs the libtorch call and lifts the result into the
-- caller's 'MonadThrow' monad.
all = unsafeThrowableIO . ATen.cast1 ATen.all_t

-- | Reduces each row of the input tensor in the selected dimension to True if all elements in the row evaluate to True and False otherwise.
-- The selected dimension is kept with size 1.
-- For a version that accepts non-singleton parameters see 'allDim'.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> shape = SShape $ SName @"*" :&: SSize @2 :|: SName @"*" :&: SSize @4 :|: SNil
-- >>> (t, _) <- sRandn (TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) shape) g
-- >>> t' <- sAllDim (SSelectDim (SByIndex @1)) =<< bool t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Bool)
--        ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1)])
--
-- >>> sAllDim (SUncheckedSelectDim (ByIndex 3)) t
-- *** Exception: HasktorchException "Exception: Dimension out of range (expected to be in range of [-2, 1], but got 3)...
sAllDim ::
  forall selectDim gradient layout device dataType shape shape' m.
  (MonadThrow m, shape' ~ BoolReductionF "all" selectDim shape, Catch shape') =>
  SSelectDim selectDim ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device ('DataType 'Bool) shape')
sAllDim by tensor =
  unsafeThrowableIO $ case forgetIsChecked (fromSing by) of
    -- keepDim = True keeps the reduced dimension with size 1, which is the
    -- shape computed by 'BoolReductionF'.
    ByName name ->
      ATen.cast3 ATen.all_tnb tensor name True -- keepDim
    ByIndex index ->
      ATen.cast3 ATen.all_tlb tensor (fromInteger index :: Int) True -- keepDim

-- | Result shape of 'allDim': the selected dimension is kept with size 1,
-- and type errors are labelled with @"all"@.
type AllDimF ::
  SelectDim (By Symbol Nat) ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type AllDimF sel s = BoolReductionF "all" sel s

-- | Reduces each row of the input tensor in the selected dimension to True if all elements in the row evaluate to True and False otherwise.
-- The dimension is chosen with the @selectDim@ type argument.
-- For a version that accepts singleton parameters see 'sAllDim'.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> type Shape' = 'Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 4) ]
-- >>> (t, _) <- randn @('Gradient 'WithoutGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @Shape' g
-- >>> t' <- allDim @('SelectDim ('ByIndex 1)) =<< bool t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Bool)
--        ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1)])
allDim ::
  forall selectDim gradient layout device dataType shape shape' m.
  (SingI selectDim, MonadThrow m, shape' ~ AllDimF selectDim shape, Catch shape') =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device ('DataType 'Bool) shape')
-- Materialise the type-level selection as a singleton and delegate.
allDim = sAllDim (sing @selectDim)

-- | Tests whether any element of the input evaluates to True.
-- The result is a rank-0 (scalar) boolean tensor.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> shape = SShape $ SName @"*" :&: SSize @2 :|: SName @"*" :&: SSize @4 :|: SNil
-- >>> (t, _) <- sRandn (TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) shape) g
-- >>> t' <- any =<< bool t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Bool)
--        ('Shape '[])
any ::
  forall requiresGradient layout device dataType shape m.
  MonadThrow m =>
  Tensor requiresGradient layout device dataType shape ->
  m (Tensor requiresGradient layout device ('DataType 'Bool) ('Shape '[]))
-- 'unsafeThrowableIO' runs the libtorch call and lifts the result into the
-- caller's 'MonadThrow' monad.
any = unsafeThrowableIO . ATen.cast1 ATen.any_t

-- | Result shape of 'anyDim' / 'sAnyDim': the selected dimension is kept
-- with size 1, and type errors are labelled with @"any"@.
type AnyDimF ::
  SelectDim (By Symbol Nat) ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type AnyDimF sel s = BoolReductionF "any" sel s

-- | Reduces each row of the input tensor in the selected dimension to True if any element in the row evaluates to True and False otherwise.
-- The selected dimension is kept with size 1.
-- For a version that accepts non-singleton parameters see 'anyDim'.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> shape = SShape $ SName @"*" :&: SSize @2 :|: SName @"*" :&: SSize @4 :|: SNil
-- >>> (t, _) <- sRandn (TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) shape) g
-- >>> t' <- sAnyDim (SSelectDim (SByIndex @1)) =<< bool t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Bool)
--        ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1)])
--
-- >>> sAnyDim (SUncheckedSelectDim (ByIndex 3)) t
-- *** Exception: HasktorchException "Exception: Dimension out of range (expected to be in range of [-2, 1], but got 3)...
sAnyDim ::
  forall selectDim gradient layout device shape dataType shape' m.
  (MonadThrow m, shape' ~ AnyDimF selectDim shape, Catch shape') =>
  SSelectDim selectDim ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device ('DataType 'Bool) shape')
sAnyDim selectDim tensor =
  unsafeThrowableIO $ case forgetIsChecked (fromSing selectDim) of
    -- keepDim = True keeps the reduced dimension with size 1, which is the
    -- shape computed by 'AnyDimF'.
    ByName name ->
      ATen.cast3 ATen.any_tnb tensor name True -- keepDim
    ByIndex index ->
      ATen.cast3 ATen.any_tlb tensor (fromInteger index :: Int) True -- keepDim

-- | Reduces each row of the input tensor in the selected dimension to True if any element in the row evaluates to True and False otherwise.
-- The dimension is chosen with the @selectDim@ type argument.
-- For a version that accepts singleton parameters see 'sAnyDim'.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> type Shape' = 'Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 4) ]
-- >>> (t, _) <- randn @('Gradient 'WithoutGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @Shape' g
-- >>> t' <- anyDim @('SelectDim ('ByIndex 1)) =<< bool t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Bool)
--        ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1)])
anyDim ::
  forall selectDim gradient layout device dataType shape shape' m.
  (SingI selectDim, MonadThrow m, shape' ~ AnyDimF selectDim shape, Catch shape') =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device ('DataType 'Bool) shape')
-- Materialise the type-level selection as a singleton and delegate.
anyDim = sAnyDim (sing @selectDim)

-- | Walks the selected dimensions left to right, setting the size of each
-- selected dimension to @'Size 1@. A selection that cannot be applied
-- produces a type error labelled @"mean"@ via 'ReductionCheckF'.
type MeanSelectDimsF ::
  [By Symbol Nat] ->
  [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)]
type family MeanSelectDimsF bys dims where
  MeanSelectDimsF '[] acc = acc
  MeanSelectDimsF (sel ': rest) acc =
    MeanSelectDimsF rest (ReductionCheckF "mean" sel acc (ReplaceDimSizeImplF sel acc ('Size 1)))

-- | Result shape of 'meanDims' / 'sMeanDims'. An unchecked selection or an
-- unchecked input shape gives an unchecked result; otherwise every selected
-- dimension is kept with size 1 (see 'MeanSelectDimsF').
type MeanF ::
  SelectDims [By Symbol Nat] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family MeanF selectDims shape where
  MeanF 'UncheckedSelectDims _ = 'UncheckedShape
  MeanF _ 'UncheckedShape = 'UncheckedShape
  MeanF ('SelectDims sels) ('Shape ds) = 'Shape (MeanSelectDimsF sels ds)

-- | Reduces the mean value over each row of the input tensor in the dimensions selected by 'selectDims'.
-- Every selected dimension is kept with size 1. Selecting no dimensions
-- returns the input unchanged.
-- For a version that accepts non-singleton parameters see 'meanDims'.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> shape = SShape $ SName @"batch" :&: SSize @8 :|: SName @"width" :&: SSize @224 :|: SName @"height" :&: SSize @224 :|: SNil
-- >>> (t, _) <- sRandn (TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) shape) g
-- >>> t' <- sMeanDims (SSelectDims $ SByName @"width" :|: SByName @"height" :|: SNil) t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 8), 'Dim ('Name "width") ('Size 1),
--              'Dim ('Name "height") ('Size 1)])
--
-- >>> sMeanDims (SUncheckedSelectDims [ByName "feature"]) t
-- *** Exception: HasktorchException "Exception: Name 'feature' not found in Tensor['batch', 'width', 'height']...
sMeanDims ::
  forall selectDims gradient layout device dataType shape shape' m.
  (MonadThrow m, shape' ~ MeanF selectDims shape, Catch shape') =>
  SSelectDims selectDims ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
sMeanDims bys tensor =
  let bys' :: [By String Integer]
      bys' = forgetIsChecked $ fromSing bys
      -- Names and indexes are reduced through different libtorch overloads
      -- ('ATen.mean_tNb' vs 'ATen.mean_tlb'), so split the selection first.
      (names, indexes) =
        flip execState (Set.empty, Set.empty) $
          for_ bys' $ \case
            ByName name -> modify . first $ Set.insert name
            ByIndex index -> modify . second $ Set.insert index
   in unsafeThrowableIO $ case (names, indexes) of
        (names', indexes')
          | Set.null names' && Set.null indexes' -> do
            -- Nothing selected: hand the input tensor back unchanged.
            t :: ForeignPtr ATen.Tensor <- ATen.cast tensor pure
            ATen.uncast t pure
          | Set.null names' ->
            ATen.cast1 (meanIndexes indexes') tensor
          | Set.null indexes' ->
            ATen.cast1 (meanNames names') tensor
          | otherwise -> do
            -- Both steps use keepDim = True, so the rank is unchanged before
            -- the named reduction is applied to the intermediate result.
            t' :: ForeignPtr ATen.Tensor <- ATen.cast1 (meanIndexes indexes') tensor
            ATen.cast1 (meanNames names') t'
  where
    -- Mean over the dimensions with the given names, keeping them (size 1).
    meanNames :: Set.Set String -> ForeignPtr ATen.Tensor -> IO (ForeignPtr ATen.Tensor)
    meanNames names tensor' =
      ATen.cast3
        ATen.mean_tNb
        tensor'
        (Set.toList names)
        True -- keepDim
    -- Mean over the dimensions at the given indexes, keeping them (size 1).
    meanIndexes :: Set.Set Integer -> ForeignPtr ATen.Tensor -> IO (ForeignPtr ATen.Tensor)
    meanIndexes indexes tensor' =
      ATen.cast3
        ATen.mean_tlb
        tensor'
        (Set.toList indexes)
        True -- keepDim

-- | Reduces the mean value over each row of the input tensor in the dimensions selected by 'selectDims'.
-- The dimensions are chosen with the @selectDims@ type argument.
-- For a version that accepts singleton parameters see 'sMeanDims'.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> type Shape' = 'Shape '[ 'Dim ('Name "batch") ('Size 8), 'Dim ('Name "feature") ('Size 4) ]
-- >>> (t, _) <- randn @('Gradient 'WithoutGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @Shape' g
-- >>> t' <- meanDims @('SelectDims '[ 'ByName "feature" ]) t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 8),
--              'Dim ('Name "feature") ('Size 1)])
meanDims ::
  forall selectDims gradient layout device dataType shape shape' m.
  (SingI selectDims, MonadThrow m, shape' ~ MeanF selectDims shape, Catch shape') =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
-- Materialise the type-level selection as a singleton and delegate.
meanDims = sMeanDims (sing @selectDims)

-- | Type-level error message reported by 'DimPositiveF' when the reduction
-- named @reduction@ meets the dimension @dim@, whose size is statically
-- known to be 0.
type DimPositiveMessage :: Symbol -> Dim (Name Symbol) (Size Nat) -> ErrorMessage

type DimPositiveMessage reduction dim =
  "Cannot apply '" <> reduction <> "' because the dimension"
    % ""
    % "    '" <> dim <> "'"
    % ""
    % "is not positive."

-- | Constraint that fails with 'DimPositiveMessage' when a dimension's size
-- is statically known to be 0. Unchecked sizes and every other known size
-- are accepted.
type DimPositiveF :: Symbol -> Dim (Name Symbol) (Size Nat) -> Constraint
type family DimPositiveF reduction dim where
  DimPositiveF _ ('Dim _ 'UncheckedSize) = ()
  DimPositiveF op ('Dim n ('Size 0)) =
    TypeError (DimPositiveMessage op ('Dim n ('Size 0)))
  DimPositiveF _ ('Dim _ ('Size _)) = ()

-- | Conjunction of 'DimPositiveF' over every dimension in the list.
type AllDimsPositiveImplF :: Symbol -> [Dim (Name Symbol) (Size Nat)] -> Constraint
type family AllDimsPositiveImplF reduction dims where
  AllDimsPositiveImplF _ '[] = ()
  AllDimsPositiveImplF op (d ': ds) =
    (DimPositiveF op d, AllDimsPositiveImplF op ds)

-- | Requires every dimension of a checked shape to have non-zero size.
-- An unchecked shape imposes no constraint.
type AllDimsPositiveF :: Symbol -> Shape [Dim (Name Symbol) (Size Nat)] -> Constraint
type family AllDimsPositiveF reduction shape where
  AllDimsPositiveF _ 'UncheckedShape = ()
  AllDimsPositiveF op ('Shape ds) = AllDimsPositiveImplF op ds

-- | Precondition of 'meanAll': no dimension may be statically known to have
-- size 0. Errors are labelled @"meanAll"@.
type MeanAllCheckF ::
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Constraint
type MeanAllCheckF s = AllDimsPositiveF "meanAll" s

-- | Reduces a tensor by calculating the mean value over all dimensions.
-- The result is a rank-0 (scalar) tensor of the same data type.
--
-- 'MeanAllCheckF' rejects, at compile time, shapes that are statically known
-- to contain a dimension of size 0.
--
-- NOTE(review): unlike most reductions in this module, this runs the libtorch
-- call with 'unsafePerformIO' instead of 'unsafeThrowableIO'. A libtorch
-- error is therefore not reported through 'MonadThrow'.
meanAll ::
  forall gradient layout device dataType shape.
  MeanAllCheckF shape =>
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType ('Shape '[])
meanAll = unsafePerformIO . ATen.cast1 ATen.mean_t

-- | Result shape of 'argmax': the selected dimension is kept with size 1,
-- and type errors are labelled with @"argmax"@.
type ArgmaxF ::
  SelectDim (By Symbol Nat) ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type ArgmaxF sel s = BoolReductionF "argmax" sel s

-- | Indices of the maximum values of a tensor along the selected dimension.
--
-- The selected dimension is kept with size 1 (see 'ArgmaxF'). The result
-- holds 'Int64' indices and never requires a gradient.
--
-- Only selection by index is supported, because this module binds only the
-- index-based 'ATen.argmax_tlb'. Selecting the dimension by name throws an
-- 'IOError' through 'MonadThrow' instead of crashing with 'undefined'.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> spec = TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNoName :&: SSize @2 :|: SNoName :&: SSize @5 :|: SNil)
-- >>> (t, _) <- sRandn spec g
-- >>> r <- argmax (SSelectDim $ SByIndex @1) t
-- >>> :type r
-- r :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Int64)
--        ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1)])
-- >>> r
-- Tensor Int64 [2,1] [[ 0],
--                     [ 2]]
argmax ::
  forall selectDims gradient layout device dataType shape shape' m.
  (MonadThrow m, shape' ~ ArgmaxF selectDims shape, Catch shape') =>
  SSelectDim selectDims ->
  Tensor gradient layout device dataType shape ->
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int64) shape')
argmax selectDim input =
  case forgetIsChecked . fromSing $ selectDim of
    ByName name ->
      -- 'ArgmaxF' accepts a name at the type level, so this case must be
      -- rejected at runtime with a descriptive error.
      throwM . userError $
        "argmax: selecting a dimension by name is not supported (got '"
          <> name
          <> "'); select the dimension by index instead."
    ByIndex index ->
      unsafeThrowableIO $
        ATen.cast3
          ATen.argmax_tlb
          input
          (fromInteger index :: Int)
          True -- keepDim

-- | Precondition of 'maxAll': no dimension may be statically known to have
-- size 0. Errors are labelled @"maxAll"@.
type MaxAllCheckF ::
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Constraint
type MaxAllCheckF s = AllDimsPositiveF "maxAll" s

-- | Reduces a tensor to the maximum value over all of its elements.
-- The result is a rank-0 (scalar) tensor of the same data type.
--
-- 'MaxAllCheckF' rejects, at compile time, shapes that are statically known
-- to contain a dimension of size 0.
--
-- NOTE(review): this runs the libtorch call with 'unsafePerformIO' instead of
-- 'unsafeThrowableIO'. A libtorch error (e.g. an empty tensor with an
-- unchecked shape) is therefore not reported through 'MonadThrow'.
maxAll ::
  forall gradient layout device dataType shape.
  MaxAllCheckF shape =>
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType ('Shape '[])
maxAll = unsafePerformIO . ATen.cast1 ATen.max_t