{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
module Torch.GraduallyTyped.Tensor.MathOperations.Reduction where
import Control.Monad.Catch (MonadThrow (throwM))
import Control.Monad.State (execState, modify)
import Data.Bifunctor (Bifunctor (first), second)
import Data.Foldable (for_)
import Data.Kind (Constraint)
import qualified Data.Set as Set
import Data.Singletons (SingI (..), SingKind (..))
import Foreign.ForeignPtr (ForeignPtr)
import GHC.TypeLits (ErrorMessage, Nat, Symbol, TypeError)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (ReplaceDimSizeImplF)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SSelectDim, SSelectDims, SelectDim (..), SelectDims (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import qualified Torch.Internal.Cast as ATen (cast1, cast3)
import qualified Torch.Internal.Class as ATen (Castable (cast), uncast)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Type as ATen (Tensor)
import Type.Errors.Pretty (type (%), type (<>))
import Prelude hiding (all, any)
-- | Type-level error text used when a reduction named @reduction@ cannot
-- select the dimension described by @by@ in the dimension list @dims@.
type ReductionErrorMessage :: Symbol -> By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage
type ReductionErrorMessage reduction by dims =
  "Cannot apply '" <> reduction <> "' on the dimension matching"
    % ""
    % " '" <> by <> "'"
    % ""
    % "in the shape"
    % ""
    % " '" <> dims <> "'."
    % ""
-- | Unwraps the result of a dimension update.
--
-- A @'Just dims'@ result yields @dims'@. A @'Nothing@ result (presumably:
-- the selected dimension could not be matched — see 'ReplaceDimSizeImplF')
-- is reported as a compile-time 'ReductionErrorMessage'.
type ReductionCheckF ::
  Symbol ->
  By Symbol Nat ->
  [Dim (Name Symbol) (Size Nat)] ->
  Maybe [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)]
type family ReductionCheckF reduction by dims result where
  ReductionCheckF reduction by dims 'Nothing = TypeError (ReductionErrorMessage reduction by dims)
  ReductionCheckF _ _ _ ('Just dims') = dims'
-- | Output shape of a reduction over a single selected dimension.
--
-- The selected dimension's size is replaced by 1 (the dimension is kept).
-- An unchecked selector or an unchecked input shape yields 'UncheckedShape'.
type BoolReductionF ::
  Symbol ->
  SelectDim (By Symbol Nat) ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family BoolReductionF reduction selectDim shape where
  BoolReductionF _ 'UncheckedSelectDim _ = 'UncheckedShape
  BoolReductionF _ _ 'UncheckedShape = 'UncheckedShape
  BoolReductionF reduction ('SelectDim by) ('Shape dims) = 'Shape (ReductionCheckF reduction by dims (ReplaceDimSizeImplF by dims ('Size 1)))
-- | Applies ATen's @all@ ('ATen.all_t') to the whole tensor.
--
-- The result is a scalar (@'Shape '[]@) tensor of data type 'Bool'. C++
-- exceptions raised by the call are surfaced through 'MonadThrow' by
-- 'unsafeThrowableIO'.
all ::
  forall requiresGradient layout device dataType shape m.
  MonadThrow m =>
  Tensor requiresGradient layout device dataType shape ->
  m (Tensor requiresGradient layout device ('DataType 'Bool) ('Shape '[]))
all = unsafeThrowableIO . ATen.cast1 ATen.all_t
-- | Applies ATen's @all@ along the dimension chosen by the singleton @by@.
--
-- The dimension is selected by name ('ATen.all_tnb') or by index
-- ('ATen.all_tlb'). The trailing 'True' argument keeps the reduced
-- dimension, which the result shape 'BoolReductionF' models as size 1.
sAllDim ::
  forall selectDim gradient layout device dataType shape shape' m.
  (MonadThrow m, shape' ~ BoolReductionF "all" selectDim shape, Catch shape') =>
  SSelectDim selectDim ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device ('DataType 'Bool) shape')
sAllDim by tensor =
  unsafeThrowableIO $
    case forgetIsChecked (fromSing by) of
      ByName name -> ATen.cast3 ATen.all_tnb tensor name True
      ByIndex index -> ATen.cast3 ATen.all_tlb tensor (fromInteger index :: Int) True
-- | Result shape of 'allDim': 'BoolReductionF' specialised to @"all"@.
type AllDimF :: SelectDim (By Symbol Nat) -> Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type AllDimF selectDim shape = BoolReductionF "all" selectDim shape
-- | Like 'sAllDim', but the dimension selector is taken from the type
-- (via 'SingI') instead of being passed as a singleton value.
allDim ::
  forall selectDim gradient layout device dataType shape shape' m.
  (SingI selectDim, MonadThrow m, shape' ~ AllDimF selectDim shape, Catch shape') =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device ('DataType 'Bool) shape')
allDim = sAllDim (sing @selectDim)
-- | Applies ATen's @any@ ('ATen.any_t') to the whole tensor.
--
-- The result is a scalar (@'Shape '[]@) tensor of data type 'Bool'. C++
-- exceptions raised by the call are surfaced through 'MonadThrow' by
-- 'unsafeThrowableIO'.
any ::
  forall requiresGradient layout device dataType shape m.
  MonadThrow m =>
  Tensor requiresGradient layout device dataType shape ->
  m (Tensor requiresGradient layout device ('DataType 'Bool) ('Shape '[]))
any = unsafeThrowableIO . ATen.cast1 ATen.any_t
-- | Result shape of 'sAnyDim' / 'anyDim': 'BoolReductionF' specialised to @"any"@.
type AnyDimF :: SelectDim (By Symbol Nat) -> Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type AnyDimF selectDim shape = BoolReductionF "any" selectDim shape
-- | Applies ATen's @any@ along the dimension chosen by the singleton
-- @selectDim@.
--
-- The dimension is selected by name ('ATen.any_tnb') or by index
-- ('ATen.any_tlb'). The trailing 'True' argument keeps the reduced
-- dimension, which 'AnyDimF' models as size 1.
--
-- NOTE: the type variables are quantified as @shape dataType@ (not
-- @dataType shape@ as elsewhere); the order is kept so existing
-- type applications keep working.
sAnyDim ::
  forall selectDim gradient layout device shape dataType shape' m.
  (MonadThrow m, shape' ~ AnyDimF selectDim shape, Catch shape') =>
  SSelectDim selectDim ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device ('DataType 'Bool) shape')
sAnyDim selectDim tensor =
  unsafeThrowableIO $
    case forgetIsChecked (fromSing selectDim) of
      ByName name -> ATen.cast3 ATen.any_tnb tensor name True
      ByIndex index -> ATen.cast3 ATen.any_tlb tensor (fromInteger index :: Int) True
-- | Like 'sAnyDim', but the dimension selector is taken from the type
-- (via 'SingI') instead of being passed as a singleton value.
anyDim ::
  forall selectDim gradient layout device dataType shape shape' m.
  (SingI selectDim, MonadThrow m, shape' ~ AnyDimF selectDim shape, Catch shape') =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device ('DataType 'Bool) shape')
anyDim = sAnyDim (sing @selectDim)
-- | Folds over the selectors @bys@, setting each selected dimension's size
-- to 1. An unmatched selector is reported through 'ReductionCheckF'.
type family MeanSelectDimsF (bys :: [By Symbol Nat]) (dims :: [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where
  MeanSelectDimsF '[] dims = dims
  MeanSelectDimsF (by ': bys) dims = MeanSelectDimsF bys (ReductionCheckF "mean" by dims (ReplaceDimSizeImplF by dims ('Size 1)))
-- | Result shape of 'sMeanDims' / 'meanDims'.
--
-- An unchecked selector list or unchecked input shape yields 'UncheckedShape';
-- otherwise every selected dimension is kept with size 1.
type family MeanF (selectDims :: SelectDims [By Symbol Nat]) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  MeanF 'UncheckedSelectDims _ = 'UncheckedShape
  MeanF _ 'UncheckedShape = 'UncheckedShape
  MeanF ('SelectDims bys) ('Shape dims) = 'Shape (MeanSelectDimsF bys dims)
-- | Averages the tensor over every dimension selected by the singleton @bys@.
--
-- The selectors are split into a set of names and a set of indexes, so
-- duplicate selectors collapse. Indexes are reduced first with
-- 'ATen.mean_tlb', then names with 'ATen.mean_tNb'. Both calls pass 'True'
-- as the last argument (keep the reduced dimensions), matching 'MeanF',
-- which sets each selected size to 1. An empty selection returns the input
-- tensor unchanged; it is only re-wrapped at the new type index.
sMeanDims ::
  forall selectDims gradient layout device dataType shape shape' m.
  (MonadThrow m, shape' ~ MeanF selectDims shape, Catch shape') =>
  SSelectDims selectDims ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
sMeanDims bys tensor =
  let bys' :: [By String Integer]
      bys' = forgetIsChecked $ fromSing bys
      -- Partition the selectors into (names, indexes).
      (names, indexes) = flip execState (Set.empty, Set.empty) $
        for_ bys' $ \case
          ByName name -> modify . first $ Set.insert name
          ByIndex index -> modify . second $ Set.insert index
   in unsafeThrowableIO $
        case (Set.null names, Set.null indexes) of
          (True, True) -> do
            -- Nothing to reduce: re-type the same underlying tensor.
            t :: ForeignPtr ATen.Tensor <- ATen.cast tensor pure
            ATen.uncast t pure
          (True, False) -> ATen.cast1 (meanIndexes indexes) tensor
          (False, True) -> ATen.cast1 (meanNames names) tensor
          (False, False) -> do
            t' :: ForeignPtr ATen.Tensor <- ATen.cast1 (meanIndexes indexes) tensor
            ATen.cast1 (meanNames names) t'
  where
    -- Mean over the given dimension names, keeping the reduced dimensions.
    meanNames :: Set.Set String -> ForeignPtr ATen.Tensor -> IO (ForeignPtr ATen.Tensor)
    meanNames dimNames tensor' =
      ATen.cast3 ATen.mean_tNb tensor' (Set.toList dimNames) True

    -- Mean over the given dimension indexes, keeping the reduced dimensions.
    meanIndexes :: Set.Set Integer -> ForeignPtr ATen.Tensor -> IO (ForeignPtr ATen.Tensor)
    meanIndexes dimIndexes tensor' =
      ATen.cast3 ATen.mean_tlb tensor' (Set.toList dimIndexes) True
-- | Like 'sMeanDims', but the dimension selectors are taken from the type
-- (via 'SingI') instead of being passed as a singleton value.
meanDims ::
  forall selectDims gradient layout device dataType shape shape' m.
  (SingI selectDims, MonadThrow m, shape' ~ MeanF selectDims shape, Catch shape') =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
meanDims = sMeanDims (sing @selectDims)
-- | Type-level error text used when reduction @reduction@ meets a dimension
-- whose size is statically known not to be positive.
type DimPositiveMessage :: Symbol -> Dim (Name Symbol) (Size Nat) -> ErrorMessage
type DimPositiveMessage reduction dim =
  "Cannot apply '" <> reduction <> "' because the dimension"
    % ""
    % " '" <> dim <> "'"
    % ""
    % "is not positive."
-- | Constraint on a single dimension.
--
-- It is trivially satisfied for an unchecked size or any size other than 0.
-- For a statically known size 0 it raises 'DimPositiveMessage'.
type DimPositiveF :: Symbol -> Dim (Name Symbol) (Size Nat) -> Constraint
type family DimPositiveF reduction dim where
  DimPositiveF _ ('Dim _ 'UncheckedSize) = ()
  DimPositiveF reduction ('Dim name ('Size 0)) = TypeError (DimPositiveMessage reduction ('Dim name ('Size 0)))
  DimPositiveF _ ('Dim _ ('Size _size)) = ()
-- | Conjunction of 'DimPositiveF' over every dimension in the list.
type AllDimsPositiveImplF :: Symbol -> [Dim (Name Symbol) (Size Nat)] -> Constraint
type family AllDimsPositiveImplF reduction dims where
  AllDimsPositiveImplF _ '[] = ()
  AllDimsPositiveImplF reduction (dim ': dims) = (DimPositiveF reduction dim, AllDimsPositiveImplF reduction dims)
-- | 'AllDimsPositiveImplF' lifted to shapes; an unchecked shape imposes no
-- constraint.
type AllDimsPositiveF :: Symbol -> Shape [Dim (Name Symbol) (Size Nat)] -> Constraint
type family AllDimsPositiveF reduction shape where
  AllDimsPositiveF _ 'UncheckedShape = ()
  AllDimsPositiveF reduction ('Shape dims) = AllDimsPositiveImplF reduction dims
-- | Static precondition of 'meanAll': no dimension is statically known to
-- have size 0.
type MeanAllCheckF :: Shape [Dim (Name Symbol) (Size Nat)] -> Constraint
type MeanAllCheckF shape = AllDimsPositiveF "meanAll" shape
-- | Applies ATen's @mean@ ('ATen.mean_t') to the whole tensor, producing a
-- scalar (@'Shape '[]@) tensor.
--
-- 'MeanAllCheckF' rejects, at compile time, shapes with a statically known
-- zero-sized dimension. Unlike most reductions in this module, the result is
-- pure: the ATen call runs under 'unsafePerformIO', so a failure in the call
-- is not routed through 'MonadThrow'.
meanAll ::
  forall gradient layout device dataType shape.
  MeanAllCheckF shape =>
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType ('Shape '[])
meanAll = unsafePerformIO . ATen.cast1 ATen.mean_t
-- | Result shape of 'argmax': the selected dimension is kept with size 1.
type ArgmaxF :: SelectDim (By Symbol Nat) -> Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type ArgmaxF selectDim shape = BoolReductionF "argmax" selectDim shape
-- | Indices of the maximal values along the dimension chosen by the
-- singleton @selectDim@, via 'ATen.argmax_tlb'.
--
-- The trailing 'True' argument keeps the reduced dimension (see 'ArgmaxF').
-- The result carries no gradient and has data type 'Int64'.
--
-- Only selection by index is supported. Selecting by name throws an
-- 'IOError' in @m@ via 'throwM' (previously this branch was 'undefined').
argmax ::
  forall selectDim gradient layout device dataType shape shape' m.
  (MonadThrow m, shape' ~ ArgmaxF selectDim shape, Catch shape') =>
  SSelectDim selectDim ->
  Tensor gradient layout device dataType shape ->
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int64) shape')
argmax selectDim input =
  case forgetIsChecked (fromSing selectDim) of
    ByName name ->
      -- Thrown directly in m: unsafeThrowableIO would not route an IOError
      -- through MonadThrow.
      throwM . userError $
        "argmax: selecting the dimension by name ('"
          ++ name
          ++ "') is not supported; select it by index instead."
    ByIndex index ->
      unsafeThrowableIO $
        ATen.cast3 ATen.argmax_tlb input (fromInteger index :: Int) True
-- | Static precondition of 'maxAll': no dimension is statically known to
-- have size 0.
type MaxAllCheckF :: Shape [Dim (Name Symbol) (Size Nat)] -> Constraint
type MaxAllCheckF shape = AllDimsPositiveF "maxAll" shape
-- | Applies ATen's @max@ ('ATen.max_t') to the whole tensor, producing a
-- scalar (@'Shape '[]@) tensor.
--
-- 'MaxAllCheckF' rejects, at compile time, shapes with a statically known
-- zero-sized dimension. For unchecked shapes nothing is checked here. The
-- ATen call runs under 'unsafePerformIO', so a failure in the call is not
-- routed through 'MonadThrow'.
maxAll ::
  forall gradient layout device dataType shape.
  MaxAllCheckF shape =>
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType ('Shape '[])
maxAll = unsafePerformIO . ATen.cast1 ATen.max_t