{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Tensor comparison operations ('gt', 'lt', 'ge', 'le', 'eq', 'ne' and
-- their operator aliases) and sorting along a selected dimension ('sort',
-- 'sSort').
module Torch.GraduallyTyped.Tensor.MathOperations.Comparison where

import Control.Monad.Catch (MonadThrow)
import Data.Singletons (SingI (..), SingKind (..))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF, GetDimImplF)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SSelectDim, SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.Internal.Cast (cast2, cast3)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Tuple as ATen ()
import Type.Errors.Pretty (TypeError, type (%), type (<>))

-- | Compare two tensors, producing a boolean tensor.
--
-- Each function forwards both tensors to the matching ATen binding
-- (@ATen.gt_tt@, @ATen.lt_tt@, @ATen.ge_tt@, @ATen.le_tt@, @ATen.eq_tt@,
-- @ATen.ne_tt@). The operator forms ('>.', '<.', '>=.', '<=.', '==.', '/=.')
-- are aliases of the named functions.
--
-- At the type level:
--
--   * the two data types must unify (@Catch (dataType <+> dataType')@);
--   * the result shape is @BroadcastShapesF shape shape'@;
--   * layouts and devices of the result are the unifications of the inputs;
--   * the result never requires a gradient and has data type @Bool@.
--
-- The ATen call is lifted into @m@ with 'unsafeThrowableIO'; presumably a
-- failure inside the C++ call is rethrown through 'MonadThrow' — confirm
-- against "Torch.Internal.GC".
gt, lt, ge, le, eq, ne, (>.), (<.), (>=.), (<=.), (==.), (/=.) ::
  forall gradient layout device dataType shape gradient' layout' device' dataType' shape' shape'' m.
  (MonadThrow m, Catch (dataType <+> dataType'), shape'' ~ BroadcastShapesF shape shape', Catch shape'') =>
  Tensor gradient layout device dataType shape ->
  Tensor gradient' layout' device' dataType' shape' ->
  m
    ( Tensor
        ('Gradient 'WithoutGradient)
        (layout <+> layout')
        (device <+> device')
        ('DataType 'Bool)
        shape''
    )
a `gt` b = unsafeThrowableIO $ cast2 ATen.gt_tt a b
a `lt` b = unsafeThrowableIO $ cast2 ATen.lt_tt a b
a `ge` b = unsafeThrowableIO $ cast2 ATen.ge_tt a b
a `le` b = unsafeThrowableIO $ cast2 ATen.le_tt a b
a `eq` b = unsafeThrowableIO $ cast2 ATen.eq_tt a b
a `ne` b = unsafeThrowableIO $ cast2 ATen.ne_tt a b
(>.) = gt
(<.) = lt
(>=.) = ge
(<=.) = le
(==.) = eq
(/=.) = ne

-- | Sort direction used by 'sort' and 'sSort'.
--
-- 'Descending' is passed to ATen as the boolean "descending" flag; any other
-- value ('Ascending') passes @False@.
data Order = Ascending | Descending
  deriving stock (Show, Eq, Ord, Generic)

-- | Result of sorting a tensor.
data Sorted gradient layout device dataType shape where
  Sorted ::
    forall gradient layout device dataType shape.
    { -- | The tensor with values sorted along the selected dimension.
      sorted :: Tensor gradient layout device dataType shape,
      -- | @Int64@ tensor of the same shape holding the positions the
      -- sorted values came from; it never requires a gradient.
      indices :: Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int64) shape
    } ->
    Sorted gradient layout device dataType shape
  deriving stock (Show, Generic)

-- | Type error reported when the selected sort dimension is not in the shape.
type SortErrorMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot apply sort on the dimension matching"
    % ""
    % " '" <> by <> "'"
    % ""
    % "in the shape"
    % ""
    % " '" <> dims <> "'."
    % ""

-- | Return @dims@ unchanged when the dimension lookup succeeded ('Just');
-- otherwise raise 'SortErrorMessage' as a type error.
type family SortCheckF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (result :: Maybe (Dim (Name Symbol) (Size Nat))) :: [Dim (Name Symbol) (Size Nat)] where
  SortCheckF by dims 'Nothing = TypeError (SortErrorMessage by dims)
  SortCheckF _ dims ('Just _) = dims

-- | Shape of a sorted tensor.
--
-- Sorting keeps the shape. For a known dimension selector and a known shape,
-- it is also checked at compile time that the selector refers to an existing
-- dimension (via 'GetDimImplF'). Unchecked inputs give an unchecked shape.
type family SortF (selectDim :: SelectDim (By Symbol Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  SortF 'UncheckedSelectDim _ = 'UncheckedShape
  SortF _ 'UncheckedShape = 'UncheckedShape
  SortF ('SelectDim by) ('Shape dims) = 'Shape (SortCheckF by dims (GetDimImplF by dims))

-- | Sort a tensor along the dimension given by a singleton selector.
--
-- The demoted selector is dispatched as follows:
--
--   * a name ('ByName') calls @ATen.sort_tnb@;
--   * an index ('ByIndex') calls @ATen.sort_tlb@, with the index narrowed
--     to 'Int'.
--
-- The @(values, indices)@ tuple returned by ATen is repackaged as 'Sorted'.
sSort ::
  forall selectDim gradient layout device dataType shape.
  SSelectDim selectDim ->
  Order ->
  Tensor gradient layout device dataType shape ->
  Sorted gradient layout device dataType (SortF selectDim shape)
sSort by order tensor =
  let by' :: By String Integer
      by' = forgetIsChecked $ fromSing by
      -- ATen's flag is "descending", so only 'Descending' sets it.
      descending = order == Descending
   in -- NOTE(review): 'unsafePerformIO' keeps this pure interface. As a
      -- consequence, an exception raised by the ATen call (for example, an
      -- out-of-range index when the shape is unchecked) surfaces when the
      -- result is forced, not at the call site.
      uncurry Sorted $ case by' of
        ByName name ->
          unsafePerformIO $ cast3 ATen.sort_tnb tensor name descending
        ByIndex index ->
          unsafePerformIO $ cast3 ATen.sort_tlb tensor (fromInteger index :: Int) descending

-- | Like 'sSort', but the dimension selector is taken from the type
-- application, e.g. @sort \@('SelectDim ('ByIndex 0)) Ascending t@.
sort ::
  forall selectDim gradient layout device dataType shape.
  SingI selectDim =>
  Order ->
  Tensor gradient layout device dataType shape ->
  Sorted gradient layout device dataType (SortF selectDim shape)
sort = sSort (sing @selectDim)