{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.Tensor.MathOperations.Comparison where

import Control.Monad.Catch (MonadThrow)
import Data.Singletons (SingI (..), SingKind (..))
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF, GetDimImplF)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SSelectDim, SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.Internal.Cast (cast2, cast3)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Tuple as ATen ()
import Type.Errors.Pretty (TypeError, type (%), type (<>))

-- | Binary comparison operations on two tensors.
--
-- Every name in this group shares one type: given two tensors, the result is
-- produced in a 'MonadThrow' context and is a tensor that
--
--   * never carries a gradient (@'Gradient 'WithoutGradient@),
--   * has the 'Bool' data type,
--   * has the unified layout and device of the inputs (@<+>@), and
--   * has the shape @BroadcastShapesF shape shape'@.
--
-- NOTE(review): the 'Catch' constraints presumably surface type-level errors
-- from data type unification and shape broadcasting at compile time; confirm
-- against "Torch.GraduallyTyped.Prelude".
gt, lt, ge, le, eq, ne, (>.), (<.), (>=.), (<=.), (==.), (/=.) ::
  forall gradient layout device dataType shape gradient' layout' device' dataType' shape' shape'' m.
  ( MonadThrow m,
    Catch (dataType <+> dataType'),
    shape'' ~ BroadcastShapesF shape shape',
    Catch shape''
  ) =>
  Tensor gradient layout device dataType shape ->
  Tensor gradient' layout' device' dataType' shape' ->
  m
    ( Tensor
        ('Gradient 'WithoutGradient)
        (layout <+> layout')
        (device <+> device')
        ('DataType 'Bool)
        shape''
    )
-- 'gt' passes both tensors to ATen's @gt_tt@ through 'cast2' and lifts the
-- resulting IO action into @m@ with 'unsafeThrowableIO'.
a `gt` b = unsafeThrowableIO $ cast2 ATen.gt_tt a b
-- 'lt' passes both tensors to ATen's @lt_tt@ through 'cast2' and lifts the
-- resulting IO action into @m@ with 'unsafeThrowableIO'.
a `lt` b = unsafeThrowableIO $ cast2 ATen.lt_tt a b
-- 'ge' passes both tensors to ATen's @ge_tt@ through 'cast2' and lifts the
-- resulting IO action into @m@ with 'unsafeThrowableIO'.
a `ge` b = unsafeThrowableIO $ cast2 ATen.ge_tt a b
-- 'le' passes both tensors to ATen's @le_tt@ through 'cast2' and lifts the
-- resulting IO action into @m@ with 'unsafeThrowableIO'.
a `le` b = unsafeThrowableIO $ cast2 ATen.le_tt a b
-- 'eq' passes both tensors to ATen's @eq_tt@ through 'cast2' and lifts the
-- resulting IO action into @m@ with 'unsafeThrowableIO'.
a `eq` b = unsafeThrowableIO $ cast2 ATen.eq_tt a b
-- 'ne' passes both tensors to ATen's @ne_tt@ through 'cast2' and lifts the
-- resulting IO action into @m@ with 'unsafeThrowableIO'.
a `ne` b = unsafeThrowableIO $ cast2 ATen.ne_tt a b
-- Operator spelling of 'gt'; it is the same function under another name.
(>.) = gt
-- Operator spelling of 'lt'; it is the same function under another name.
(<.) = lt
-- Operator spelling of 'ge'; it is the same function under another name.
(>=.) = ge
-- Operator spelling of 'le'; it is the same function under another name.
(<=.) = le
-- Operator spelling of 'eq'; it is the same function under another name.
(==.) = eq
-- Operator spelling of 'ne'; it is the same function under another name.
(/=.) = ne

-- | Direction in which 'sSort' and 'sort' order the elements.
--
-- 'sSort' turns this into the boolean flag it hands to ATen by testing
-- @order == 'Descending'@, so 'Descending' maps to @True@ and 'Ascending'
-- to @False@.
data Order = Ascending | Descending
  deriving stock (Show, Eq, Ord, Generic)

-- | Result of a sort: the sorted tensor together with an index tensor.
--
-- Both fields share the same layout, device and shape parameters.
data Sorted gradient layout device dataType shape where
  Sorted ::
    forall gradient layout device dataType shape.
    { -- Sorted values; keeps the gradient and data type parameters of 'Sorted'.
      sorted :: Tensor gradient layout device dataType shape,
      -- Index tensor returned alongside the values; always 'Int64' and never
      -- carries a gradient.
      -- NOTE(review): presumably the positions of the sorted elements in the
      -- input along the sorted dimension; confirm against ATen's sort.
      indices :: Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int64) shape
    } ->
    Sorted gradient layout device dataType shape
  deriving stock (Show, Generic)

-- | Type-level error message used by 'SortCheckF' when the selector @by@ does
-- not match any dimension in @dims@. The message embeds both the selector and
-- the full list of dimensions.
type SortErrorMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot apply sort on the dimension matching"
    % ""
    % "    '" <> by <> "'"
    % ""
    % "in the shape"
    % ""
    % "    '" <> dims <> "'."
    % ""

-- | Validates a dimension lookup performed for a sort.
--
-- When the lookup @result@ is @'Just@, the dimension list is returned
-- unchanged; when it is @'Nothing@, compilation fails with
-- 'SortErrorMessage' for the given selector and dimension list.
type family
  SortCheckF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (result :: Maybe (Dim (Name Symbol) (Size Nat))) ::
    [Dim (Name Symbol) (Size Nat)]
  where
  SortCheckF _ ds ('Just _) = ds
  SortCheckF selector ds 'Nothing = TypeError (SortErrorMessage selector ds)

-- | Shape of the tensors in a 'Sorted' result.
--
-- If either the dimension selector or the input shape is unchecked, the
-- result is @'UncheckedShape@. Otherwise the dimension is looked up with
-- 'GetDimImplF' and checked by 'SortCheckF', which either keeps the
-- dimension list as it is or raises a type error.
type family
  SortF
    (selectDim :: SelectDim (By Symbol Nat))
    (shape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  SortF 'UncheckedSelectDim _ = 'UncheckedShape
  SortF _ 'UncheckedShape = 'UncheckedShape
  SortF ('SelectDim selector) ('Shape ds) = 'Shape (SortCheckF selector ds (GetDimImplF selector ds))

-- | Sort a tensor along the dimension picked by a singleton selector.
--
-- The selector is demoted to a runtime value and dispatched on:
--
--   * 'ByName' calls ATen's @sort_tnb@ with the dimension name;
--   * 'ByIndex' calls ATen's @sort_tlb@ with the index converted to 'Int'.
--
-- In both cases the boolean flag handed to ATen is @order == 'Descending'@.
-- The pair returned by ATen is packed into 'Sorted' (values first, indices
-- second). The result shape is @'SortF' selectDim shape@.
sSort ::
  forall selectDim gradient layout device dataType shape.
  SSelectDim selectDim ->
  Order ->
  Tensor gradient layout device dataType shape ->
  Sorted gradient layout device dataType (SortF selectDim shape)
sSort by order tensor =
  let -- Runtime view of the selector with the checked/unchecked wrapper removed.
      by' = forgetIsChecked $ fromSing by
   in -- NOTE(review): 'unsafePerformIO' treats the ATen sort calls as pure
      -- functions of their arguments; confirm that no observable side effect
      -- or exception is expected here (the comparison operators in this
      -- module use 'unsafeThrowableIO' instead).
      uncurry Sorted $ case by' of
        ByName name -> unsafePerformIO $ cast3 ATen.sort_tnb tensor name (order == Descending)
        ByIndex index -> unsafePerformIO $ cast3 ATen.sort_tlb tensor (fromInteger index :: Int) (order == Descending)

-- | Variant of 'sSort' that takes the dimension selector as a type argument.
--
-- The selector is supplied with a type application (the signature is
-- ambiguous otherwise) and is turned into its singleton through 'SingI'
-- before being passed on to 'sSort'.
sort ::
  forall selectDim gradient layout device dataType shape.
  SingI selectDim =>
  Order ->
  Tensor gradient layout device dataType shape ->
  Sorted gradient layout device dataType (SortF selectDim shape)
sort = sSort (sing @selectDim)