{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Functional where

import Control.Arrow ((&&&))
import Data.Finite
import qualified Data.Int as I
import Data.Kind
  ( Constraint,
    Type,
  )
import Data.Maybe
import Data.Proxy
import Data.Reflection
import Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as V
import Foreign.ForeignPtr
import GHC.Generics (Generic)
import GHC.Natural (Natural)
import GHC.TypeLits
import GHC.TypeLits.Extra
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.Functional
  ( Reduction (..),
    Tri (..),
    isUpper,
    kOne,
  )
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen.Managed
import qualified Torch.Internal.Managed.Native as ATen.Managed
import qualified Torch.Internal.Managed.Type.Scalar as ATen.Managed
import qualified Torch.Internal.Managed.Type.Tensor as ATen.Managed
import qualified Torch.Internal.Managed.Type.Tuple as ATen.Managed
import qualified Torch.Internal.Type as ATen
import qualified Torch.Scalar as D
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import qualified Torch.TensorOptions as D
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Tensor
import Prelude hiding
  ( abs,
    acos,
    acosh,
    all,
    any,
    asin,
    asinh,
    atan,
    atanh,
    ceil,
    cos,
    cosh,
    exp,
    floor,
    isNaN,
    log,
    max,
    min,
    round,
    sin,
    sinh,
    tan,
    tanh,
  )

-- $setup
--
-- >>> :set -XOverloadedLists

-- | Computes the bitwise NOT of the given input tensor.
-- The input tensor must be of integral or Boolean types.
-- For bool tensors, it computes the logical NOT.
--
-- >>> dtype &&& shape $ bitwiseNot (ones :: CPUTensor 'D.Bool [3,3])
-- (Bool,[3,3])
bitwiseNot ::
  forall device shape.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape
bitwiseNot input = unsafePerformIO $ ATen.cast1 ATen.Managed.bitwise_not_t input

-- | Computes the element-wise logical NOT of the given input tensor.
-- The output tensor has the bool dtype.
-- For non-bool inputs, zeros are treated as False and non-zeros as True; this typed wrapper accepts bool tensors only.
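--
-- >>> dtype &&& shape $ logicalNot (ones :: CPUTensor 'D.Bool '[3,3])
-- (Bool,[3,3])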
logicalNot ::
  forall device shape.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape
logicalNot input = unsafePerformIO $ ATen.cast1 ATen.Managed.logical_not_t input

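-- | Computes the element-wise logical XOR of the given input tensors.
--
-- >>> dtype &&& shape $ logicalXor (ones :: CPUTensor 'D.Bool '[2,2]) (zeros :: CPUTensor 'D.Bool '[2,2])
-- (Bool,[2,2])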
logicalXor ::
  forall device shape.
  -- | self
  Tensor device 'D.Bool shape ->
  -- | other
  Tensor device 'D.Bool shape ->
  Tensor device 'D.Bool shape
logicalXor self other = unsafePerformIO $ ATen.cast2 ATen.Managed.logical_xor_tt self other

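-- | Computes the element-wise logical AND of the given input tensors.
--
-- >>> dtype &&& shape $ logicalAnd (ones :: CPUTensor 'D.Bool '[2,2]) (ones :: CPUTensor 'D.Bool '[2,2])
-- (Bool,[2,2])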
logicalAnd ::
  forall device shape.
  -- | self
  Tensor device 'D.Bool shape ->
  -- | other
  Tensor device 'D.Bool shape ->
  Tensor device 'D.Bool shape
logicalAnd self other = unsafePerformIO $ ATen.cast2 ATen.Managed.logical_and_tt self other

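-- | Computes the element-wise logical OR of the given input tensors.
--
-- >>> dtype &&& shape $ logicalOr (ones :: CPUTensor 'D.Bool '[2,2]) (zeros :: CPUTensor 'D.Bool '[2,2])
-- (Bool,[2,2])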
logicalOr ::
  forall device shape.
  -- | self
  Tensor device 'D.Bool shape ->
  -- | other
  Tensor device 'D.Bool shape ->
  Tensor device 'D.Bool shape
logicalOr self other = unsafePerformIO $ ATen.cast2 ATen.Managed.logical_or_tt self other

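-- | 'SumDType' is the dtype that 'sumAll' and 'sumDim' accumulate into:
-- boolean and integral dtypes are summed as 'D.Int64', while floating-point dtypes keep their dtype.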
type family SumDType (dtype :: D.DType) :: D.DType where
  SumDType D.Bool = D.Int64
  SumDType D.UInt8 = D.Int64
  SumDType D.Int8 = D.Int64
  SumDType D.Int16 = D.Int64
  SumDType D.Int32 = D.Int64
  SumDType D.Int64 = D.Int64
  SumDType D.Half = D.Half
  SumDType D.Float = D.Float
  SumDType D.Double = D.Double

type family SumDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SumDTypeIsValid '( 'D.CPU, 0) dtype = DTypeIsNotHalf '( 'D.CPU, 0) dtype
  SumDTypeIsValid '( 'D.CUDA, _) dtype = ()
  SumDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | sumAll
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Bool '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.UInt8 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Int8 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Int16 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Int32 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Int64 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Float) $ sumAll (ones :: CPUTensor 'D.Float '[2, 3])
-- (Float,([],6.0))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Double) $ sumAll (ones :: CPUTensor 'D.Double '[2, 3])
-- (Double,([],6.0))
sumAll ::
  forall shape dtype' dtype device.
  ( SumDTypeIsValid device dtype,
    dtype' ~ SumDType dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype' '[]
sumAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.sum_t input

-- | sumDim
--
-- >>> dtype &&& shape $ sumDim @0 (ones :: CPUTensor 'D.Float '[3,4,5])
-- (Float,[4,5])
-- >>> sumDim @1 (ones :: CPUTensor 'D.Float '[2,4])
-- Tensor Float [2] [ 4.0000   ,  4.0000   ]
sumDim ::
  forall d shape shape' dtype dtype' device.
  ( KnownNat d,
    shape' ~ DropValue shape d,
    SumDTypeIsValid device dtype,
    dtype' ~ SumDType dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype' shape'
sumDim input = unsafePerformIO $ ATen.cast2 ATen.Managed.sum_tl input (natValI @d)

-- | abs
--
-- >>> dtype &&& shape $ abs (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
abs ::
  forall shape dtype device.
  (StandardDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
abs input = unsafePerformIO $ ATen.cast1 ATen.Managed.abs_t input

-- | ceil
--
-- >>> dtype &&& shape $ ceil (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
ceil ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
ceil input = unsafePerformIO $ ATen.cast1 ATen.Managed.ceil_t input

-- | floor
--
-- >>> dtype &&& shape $ floor (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
floor ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
floor input = unsafePerformIO $ ATen.cast1 ATen.Managed.floor_t input

type family MinMaxDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MinMaxDTypeIsValid '( 'D.CPU, 0) dtype = DTypeIsNotHalf '( 'D.CPU, 0) dtype
  MinMaxDTypeIsValid '( 'D.CUDA, _) dtype = ()
  MinMaxDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | min
--
-- >>> dtype &&& shape $ min (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
min ::
  forall shape dtype device.
  ( MinMaxDTypeIsValid device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
min input = unsafePerformIO $ ATen.cast1 ATen.Managed.min_t input

-- | max
--
-- >>> dtype &&& shape $ max (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
max ::
  forall shape dtype device.
  ( MinMaxDTypeIsValid device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
max input = unsafePerformIO $ ATen.cast1 ATen.Managed.max_t input

type family MeanDTypeValidation (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MeanDTypeValidation '(deviceType, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '(deviceType, deviceIndex) dtype,
      DTypeIsNotHalf '(deviceType, deviceIndex) dtype
    )

-- | Computes the mean while carrying out a full reduction of all tensor dimensions.
--
-- >>> meanAll (ones :: CPUTensor 'D.Float '[])
-- Tensor Float []  1.0000
-- >>> meanAll (zeros :: CPUTensor 'D.Float '[2,2])
-- Tensor Float []  0.0000
meanAll ::
  forall shape dtype device.
  ( MeanDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
meanAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.mean_t input

-- | Computes the mean while carrying out a full reduction of all tensor dimensions.
-- Unlike 'meanAll', this version does not require all dimensions to be positive, so it can return NaN (e.g. for empty tensors).
--
-- >>> unsafeMeanAll (ones :: CPUTensor 'D.Float '[])
-- Tensor Float []  1.0000
-- >>> unsafeMeanAll (ones :: CPUTensor 'D.Float '[0])
-- Tensor Float [] NaN
-- >>> unsafeMeanAll (zeros :: CPUTensor 'D.Float '[2,2])
-- Tensor Float []  0.0000
unsafeMeanAll ::
  forall shape dtype device.
  MeanDTypeValidation device dtype =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
unsafeMeanAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.mean_t input

-- | Computes the mean and reduces the tensor over the specified dimension.
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ meanDim @0 t
-- (Float,[4,5])
-- >>> dtype &&& shape $ meanDim @1 t
-- (Float,[3,5])
-- >>> dtype &&& shape $ meanDim @2 t
-- (Float,[3,4])
meanDim ::
  forall dim shape' shape dtype device.
  ( KnownNat dim,
    shape' ~ DropValue shape dim,
    MeanDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
meanDim input = unsafePerformIO $ ATen.cast2 ATen.Managed.mean_tl input (natValI @dim)

-- | Computes the mean and reduces the tensor over the specified dimension.
--
-- >>> import Torch.Typed.Factories
-- >>> import Data.Default.Class
-- >>> t = def :: NamedTensor '( D.CPU, 0) 'D.Float '[Vector 3, Vector 4, Vector 5]
-- >>> dtype &&& shape $ meanNamedDim @(Vector 4) t
-- (Float,[3,5])
meanNamedDim ::
  forall dim shape' shape dtype device.
  ( KnownNat (FindDim dim shape),
    shape' ~ DropNamedValue shape dim,
    MeanDTypeValidation device dtype
  ) =>
  -- | input
  NamedTensor device dtype shape ->
  -- | output
  NamedTensor device dtype shape'
meanNamedDim input = unsafePerformIO $ ATen.cast2 ATen.Managed.mean_tl input _dim
  where
    _dim = natValI @(FindDim dim shape)

-- | Computes the mean and optionally reduces the tensor over the specified dimension.
--
-- See https://pytorch.org/docs/stable/torch.html#torch.mean for more information.
--
-- >>> t = fromJust [[5, 1], [3, 2], [4, 1], [2, 7]] :: CPUTensor 'D.Float '[4, 2]
-- >>> mean @0 @KeepDim t
-- Tensor Float [1,2] [[ 3.5000   ,  2.7500   ]]
mean ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    MeanDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
mean input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.mean_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- | Computes the median while carrying out a full reduction of all tensor dimensions.
--
-- >>> dtype &&& shape $ medianAll (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
medianAll ::
  forall shape dtype device.
  ( StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
medianAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.median_t input

-- | Computes the median and reduces the tensor over the specified dimension.
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ fst $ medianDim @0 t
-- (Float,[4,5])
-- >>> dtype &&& shape $ fst $ medianDim @1 t
-- (Float,[3,5])
-- >>> dtype &&& shape $ fst $ medianDim @2 t
-- (Float,[3,4])
medianDim ::
  forall dim shape' shape dtype device.
  ( KnownNat dim,
    shape' ~ DropValue shape dim,
    StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  ( Tensor device dtype shape',
    Tensor device 'D.Int64 shape'
  )
medianDim input = unsafePerformIO $ ATen.cast2 ATen.Managed.median_tl input (natValI @dim)

-- | Computes the median and optionally reduces the tensor over the specified dimension.
--
-- See https://pytorch.org/docs/stable/torch.html#torch.median for more information.
--
-- >>> t = fromJust [[5, 1], [3, 2], [4, 1], [2, 7]] :: CPUTensor 'D.Float '[4, 2]
--
-- -- libtorch 1.7.0
-- -- (Tensor Float [1,2] [[ 3.0000   ,  1.0000   ]],Tensor Int64 [1,2] [[ 1,  0]])
-- -- libtorch 1.8.0
-- >>> median @0 @KeepDim t
-- (Tensor Float [1,2] [[ 3.0000   ,  1.0000   ]],Tensor Int64 [1,2] [[ 1,  2]])
median ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  (Tensor device dtype shape', Tensor device 'D.Int64 shape')
median input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.median_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- | Returns a tuple '(modes, indices)' where 'modes' is the mode value of each row of the 'input' tensor
-- in the given dimension 'dim', i.e. a value which appears most often in that row,
-- and 'indices' is the index location of each mode value found.
--
-- See https://pytorch.org/docs/stable/torch.html#torch.mode for more information.
--
-- >>> t = fromJust [[0, 5], [0, 2], [3, 5]] :: CPUTensor 'D.Int64 '[3, 2]
--
-- >>> (modes :: CPUTensor 'D.Int64 '[2], indices :: CPUTensor 'D.Int64 '[2]) = mode @0 @DropDim t
-- >>> (dtype modes, shape modes, D.asValue (toDynamic modes) :: [Int])
-- (Int64,[2],[0,5])
-- >>> (dtype indices, shape indices, D.asValue (toDynamic indices) :: [Int])
-- (Int64,[2],[1,2])
--
-- >>> t = fromJust [[0, 0], [0, 1], [3, 3]] :: CPUTensor 'D.Float '[3, 2]
--
-- >>> (modes :: CPUTensor 'D.Float '[3,1], indices :: CPUTensor 'D.Int64 '[3,1]) = mode @1 @KeepDim t
-- >>> (dtype modes, shape modes, D.asValue (toDynamic modes) :: [[Float]])
-- (Float,[3,1],[[0.0],[0.0],[3.0]])
-- >>> (dtype indices, shape indices, D.asValue (toDynamic indices) :: [[Int]])
-- (Int64,[3,1],[[1],[0],[1]])
mode ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  (Tensor device dtype shape', Tensor device 'D.Int64 shape')
mode input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.mode_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- | addScalar
-- TODO: what dtypes is this defined for?
-- TODO: what scalar types is this defined for?
--
-- >>> dtype &&& shape $ addScalar 1 (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
addScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar input
  a ->
  -- | tensor input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
addScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.add_ts input a

-- | subScalar
-- TODO: what dtypes is this defined for?
-- TODO: what scalar types is this defined for?
--
-- >>> dtype &&& shape $ subScalar 1 (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
subScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar input
  a ->
  -- | tensor input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
subScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.sub_ts input a

-- | mulScalar
-- TODO: what dtypes is this defined for?
-- TODO: what scalar types is this defined for?
--
-- >>> dtype &&& shape $ mulScalar 2 (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
mulScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar input
  a ->
  -- | tensor input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
mulScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.mul_ts input a

-- | divScalar
-- TODO: what dtypes is this defined for?
-- TODO: what scalar types is this defined for?
--
-- >>> dtype &&& shape $ divScalar 2 (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
divScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar input
  a ->
  -- | tensor input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
divScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.div_ts input a

-- | powScalar
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ powScalar 2 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
powScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | power
  a ->
  -- | input tensor
  Tensor device dtype shape ->
  -- | output tensor
  Tensor device dtype shape
powScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.pow_ts input a

-- | erf
--
-- >>> dtype &&& shape $ erf (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
erf ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
erf input = unsafePerformIO $ ATen.cast1 ATen.Managed.erf_t input

-- | exp
--
-- >>> dtype &&& shape $ exp (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
exp ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
exp input = unsafePerformIO $ ATen.cast1 ATen.Managed.exp_t input

-- | log1p
--
-- >>> dtype &&& shape $ log1p (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
log1p ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log1p input = unsafePerformIO $ ATen.cast1 ATen.Managed.log1p_t input

-- | log2
-- >>> dtype &&& shape $ log2 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
log2 ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log2 input = unsafePerformIO $ ATen.cast1 ATen.Managed.log2_t input

-- | log10
--
-- >>> dtype &&& shape $ log10 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
log10 ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log10 input = unsafePerformIO $ ATen.cast1 ATen.Managed.log10_t input

-- | pow
-- this operation supports broadcasting
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ pow (2 :: CPUTensor 'D.Float '[]) (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
pow ::
  forall shape'' shape shape' dtype device.
  ( BasicArithmeticDTypeIsValid device dtype,
    shape'' ~ Broadcast shape shape'
  ) =>
  -- | power
  Tensor device dtype shape ->
  -- | input tensor
  Tensor device dtype shape' ->
  -- | output tensor
  Tensor device dtype shape''
pow exponent input = unsafePerformIO $ ATen.cast2 ATen.Managed.pow_tt input exponent

-- | relu activation function
--
-- >>> dtype &&& shape $ relu (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
relu ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
relu input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.relu_t (Wrap input)

-- | selu
--
-- >>> dtype &&& shape $ selu (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
selu ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
selu input = unsafePerformIO $ ATen.cast1 ATen.Managed.selu_t input

-- | mish
-- `mish` is a smooth activation function, see https://arxiv.org/abs/1908.08681 for details.
--
-- >>> dtype &&& shape &&& (\t -> D.asValue (toDynamic t) :: [[Float]]) $ mish (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,([3,2],[[0.86509836,0.86509836],[0.86509836,0.86509836],[0.86509836,0.86509836]]))
mish ::
  forall shape dtype device.
  ( StandardFloatingPointDTypeValidation device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    shape ~ Broadcast shape shape
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
mish = mul =<< tanh . softplus (1 :: Float) 20

-- | sigmoid
--
-- >>> dtype &&& shape $ sigmoid (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
sigmoid ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
sigmoid input = unsafePerformIO $ ATen.cast1 ATen.Managed.sigmoid_t input

-- | sin
--
-- >>> dtype &&& shape $ sin (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
sin ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
sin input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.sin_t (Wrap input)

-- | sinh
--
-- >>> dtype &&& shape $ sinh (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
sinh ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
sinh input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.sinh_t (Wrap input)

-- | cos
--
-- >>> dtype &&& shape $ cos (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
cos ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
cos input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.cos_t (Wrap input)

-- | sqrt
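--
-- >>> dtype &&& shape $ sqrt (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])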
sqrt ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
sqrt input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.sqrt_t (Wrap input)

-- | tanh
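--
-- >>> dtype &&& shape $ tanh (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])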
tanh ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
tanh input = unsafePerformIO $ ATen.cast1 ATen.Managed.tanh_t input

-- | ConditionalReduction
--
-- >>> :kind! ConditionalReduction '[3,2] ReduceNone
-- ConditionalReduction '[3,2] ReduceNone :: [Natural]
-- = '[3, 2]
-- >>> :kind! ConditionalReduction '[3,2] ReduceMean
-- ConditionalReduction '[3,2] ReduceMean :: [Natural]
-- = '[]
type family ConditionalReduction (shape :: [Nat]) (reduction :: Reduction) :: [Nat] where
  ConditionalReduction shape ReduceNone = shape
  ConditionalReduction shape _ = '[]

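-- | 'KnownReduction' reifies a type-level 'Reduction' as the integer code expected by the ATen loss functions.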
class KnownReduction reduction where
  reductionVal :: Int

instance KnownReduction ReduceNone where
  reductionVal = 0

instance KnownReduction ReduceMean where
  reductionVal = 1

instance KnownReduction ReduceSum where
  reductionVal = 2

-- | binary cross entropy
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> dtype &&& shape $ binaryCrossEntropy @ReduceNone t t t
-- (Float,[2,2])
-- >>> dtype &&& shape $ binaryCrossEntropy @ReduceMean t t t
-- (Float,[])
-- >>> dtype &&& shape $ binaryCrossEntropy @ReduceSum t t t
-- (Float,[])
binaryCrossEntropy ::
  forall (reduction :: Reduction) shape shape' dtype device.
  ( KnownReduction reduction,
    shape' ~ ConditionalReduction shape reduction,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  -- | weight
  Tensor device dtype shape ->
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
binaryCrossEntropy weight prediction target =
  unsafePerformIO $
    ATen.cast4
      ATen.Managed.binary_cross_entropy_tttl
      prediction
      target
      weight
      (reductionVal @reduction)

-- | mseLoss
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> dtype &&& shape $ mseLoss @ReduceNone t t
-- (Float,[2,2])
-- >>> dtype &&& shape $ mseLoss @ReduceMean t t
-- (Float,[])
-- >>> dtype &&& shape $ mseLoss @ReduceSum t t
-- (Float,[])
mseLoss ::
  forall (reduction :: Reduction) shape shape' dtype device.
  ( KnownReduction reduction,
    shape' ~ ConditionalReduction shape reduction,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
mseLoss prediction target =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.mse_loss_ttl
      prediction
      target
      (reductionVal @reduction)

-- | softmax
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> dtype &&& shape $ softmax @0 t
-- (Float,[2,2])
-- >>> dtype &&& shape $ softmax @1 t
-- (Float,[2,2])
softmax ::
  forall dim shape dtype device.
  ( KnownNat dim,
    DimOutOfBoundCheck shape dim,
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
softmax input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.softmax_tls input (natValI @dim) (dtypeVal @dtype)

-- | logSoftmax
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> dtype &&& shape $ logSoftmax @0 t
-- (Float,[2,2])
-- >>> dtype &&& shape $ logSoftmax @1 t
-- (Float,[2,2])
logSoftmax ::
  forall dim shape dtype device.
  ( KnownNat dim,
    DimOutOfBoundCheck shape dim,
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
logSoftmax input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.log_softmax_tls input (natValI @dim) (dtypeVal @dtype)

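-- | Square
--
-- >>> :kind! Square '[3,3]
-- Square '[3,3] :: [Natural]
-- = '[3, 3]
-- >>> :kind! Square '[2,3,3]
-- Square '[2,3,3] :: [Natural]
-- = '[2, 3, 3]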
type family Square (shape :: [Nat]) :: [Nat] where
  Square (n : n : '[]) = '[n, n]
  Square (b : n : n : '[]) = '[b, n, n]
  Square _ = TypeError (Text "This shape must be a square matrix or a batch of square matrices.")

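-- | VectorOfSquare
--
-- >>> :kind! VectorOfSquare '[3,3]
-- VectorOfSquare '[3,3] :: [Natural]
-- = '[3]
-- >>> :kind! VectorOfSquare '[2,3,3]
-- VectorOfSquare '[2,3,3] :: [Natural]
-- = '[2, 3]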
type family VectorOfSquare (shape :: [Nat]) :: [Nat] where
  VectorOfSquare (n : n : '[]) = '[n]
  VectorOfSquare (b : n : n : '[]) = '[b, n]
  VectorOfSquare _ = TypeError (Text "This shape must be a square matrix or a batch of square matrices.")

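-- | FstSquareDim
--
-- >>> :kind! FstSquareDim '[3,2]
-- FstSquareDim '[3,2] :: Natural
-- = 3
-- >>> :kind! FstSquareDim '[4,3,2]
-- FstSquareDim '[4,3,2] :: Natural
-- = 3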
type family FstSquareDim (shape :: [Nat]) :: Nat where
  FstSquareDim (n : m : '[]) = n
  FstSquareDim (b : n : m : '[]) = n
  FstSquareDim _ = TypeError (Text "Cannot get the first dimension of a matrix or a batch of matrices.")

type family InverseShapeIsValid (device :: (D.DeviceType, Nat)) (shape :: [Nat]) :: Constraint where
  InverseShapeIsValid '( 'D.CPU, 0) _ = ()
  InverseShapeIsValid '( 'D.CUDA, _) shape = AllDimsPositive shape

type family InverseDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  InverseDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  InverseDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  InverseDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | inverse
-- TODO: if rank < n for any tensors in the batch, then this will not work. we can't decide this statically, but we should prevent runtime errors. therefore, return Maybe?
--
-- >>> t <- randn :: IO (CPUTensor 'D.Float '[3,2,2])
-- >>> dtype &&& shape $ inverse t
-- (Float,[3,2,2])
-- >>> t <- randn :: IO (CPUTensor 'D.Float '[2,2])
-- >>> dtype &&& shape $ inverse t
-- (Float,[2,2])
inverse ::
  forall shape shape' dtype device.
  ( shape' ~ Square shape,
    InverseShapeIsValid device shape,
    InverseDTypeIsValid device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
inverse input = unsafePerformIO $ ATen.cast1 ATen.Managed.inverse_t input

type family SymeigDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SymeigDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  SymeigDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  SymeigDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | symeig
-- Warning:
-- torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future PyTorch release.
-- The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
-- L, _ = torch.symeig(A, upper=upper)
-- should be replaced with
-- L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
-- and
-- L, V = torch.symeig(A, eigenvectors=True)
-- should be replaced with
-- L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (function operator())
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[3,2,2])
-- >>> (eigenVals,eigenVecs) = symeig Upper t
-- >>> dtype &&& shape $ eigenVals -- Skip warning
-- ...
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> :t eigenVals
-- eigenVals :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 2]
-- >>> dtype &&& shape $ eigenVecs
-- (Float,[3,2,2])
-- >>> :t eigenVecs
-- eigenVecs :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 2, 2]
-- >>> (eigenVals,eigenVecs) = symeig Lower t
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> dtype &&& shape $ eigenVecs
-- (Float,[3,2,2])
symeig ::
  forall shape shape' shape'' dtype device.
  ( shape' ~ VectorOfSquare shape,
    shape'' ~ Square shape,
    SymeigDTypeIsValid device dtype
  ) =>
  -- | upper or lower triangular
  Tri ->
  -- | input
  Tensor device dtype shape ->
  -- | eigenvalues and eigenvectors
  ( Tensor device dtype shape',
    Tensor device dtype shape''
  )
symeig upper input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.symeig_tbb input True boolUpper
  where
    boolUpper = isUpper upper

-- | symeigvalues
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[3,2,2])
-- >>> eigenVals = symeigvalues Upper t
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> :t eigenVals
-- eigenVals :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 2]
symeigvalues ::
  forall shape shape' dtype device.
  ( shape' ~ VectorOfSquare shape,
    SymeigDTypeIsValid device dtype
  ) =>
  -- | upper or lower triangular
  Tri ->
  -- | input
  Tensor device dtype shape ->
  Tensor device dtype shape'
symeigvalues upper input = fst symeig'
  where
    boolUpper = isUpper upper
    symeig' :: (Tensor device dtype shape', Tensor device dtype shape'')
    symeig' = unsafePerformIO $ ATen.cast3 ATen.Managed.symeig_tbb input False boolUpper

data EigenVectors = EnableEigenVectors | DisableEigenVectors

class KnownEigenVectors a where
  enableEigenVectors :: Bool

instance KnownEigenVectors EnableEigenVectors where
  enableEigenVectors = True

instance KnownEigenVectors DisableEigenVectors where
  enableEigenVectors = False

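-- | 'ConditionalEigenVectors' computes the shape of the eigenvector component
-- returned by 'eig': a square @n x n@ matrix when eigenvector computation is
-- requested, and a placeholder of shape @'[0]@ otherwise.
--
-- >>> :kind! ConditionalEigenVectors 'EnableEigenVectors 3
-- ConditionalEigenVectors 'EnableEigenVectors 3 :: [Natural]
-- = '[3, 3]
-- >>> :kind! ConditionalEigenVectors 'DisableEigenVectors 3
-- ConditionalEigenVectors 'DisableEigenVectors 3 :: [Natural]
-- = '[0]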
type family ConditionalEigenVectors (eigenvectors :: EigenVectors) (n :: Nat) :: [Nat] where
  ConditionalEigenVectors EnableEigenVectors n = '[n, n]
  ConditionalEigenVectors DisableEigenVectors _ = '[0]

type family EigDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  EigDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  EigDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  EigDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | eig
-- Warning:
-- torch.eig is deprecated in favor of torch.linalg.eig and will be removed in a future PyTorch release.
-- torch.linalg.eig returns complex tensors of dtype cfloat or cdouble rather than real tensors mimicking complex tensors.
-- L, _ = torch.eig(A)
-- should be replaced with
-- L_complex = torch.linalg.eigvals(A)
-- and
-- L, V = torch.eig(A, eigenvectors=True)
-- should be replaced with
-- L_complex, V_complex = torch.linalg.eig(A) (function operator())
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[3,3])
-- >>> (eigenVals,eigenVecs) = eig @EnableEigenVectors t
-- >>> dtype &&& shape $ eigenVals -- Skip warning
-- ...
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> :t eigenVals
-- eigenVals :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 2]
-- >>> dtype &&& shape $ eigenVecs
-- (Float,[3,3])
-- >>> :t eigenVecs
-- eigenVecs :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 3]
-- >>> (eigenVals,eigenVecs) = eig @DisableEigenVectors t
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> dtype &&& shape $ eigenVecs
-- (Float,[0])
-- >>> :t eigenVecs
-- eigenVecs :: Tensor '( 'D.CPU, 0) 'D.Float '[0]
eig ::
  forall eigenvectors n shape dtype device.
  ( KnownNat n,
    KnownEigenVectors eigenvectors,
    shape ~ ConditionalEigenVectors eigenvectors n,
    EigDTypeIsValid device dtype
  ) =>
  -- | input matrix
  Tensor device dtype '[n, n] ->
  -- | eigenvalues and eigenvectors
  ( Tensor device dtype '[n, 2],
    Tensor device dtype shape
  )
eig input =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.eig_tb input (enableEigenVectors @eigenvectors)

type family SVDShapes (shape :: [Nat]) (reduced :: ReducedSVD) :: ([Nat], [Nat], [Nat]) where
  SVDShapes '[0, n] 'ThinSVD = '( '[0, 0], '[0], '[n, n])
  SVDShapes '[m, n] 'ThinSVD = '( '[m, Min m n], '[Min m n], '[n, Min m n])
  SVDShapes '[m, n] 'FullSVD = '( '[m, m], '[Min m n], '[n, n])
  SVDShapes '[b, 0, n] 'ThinSVD = '( '[b, 0, 0], '[b, 0], '[b, n, n])
  SVDShapes '[b, m, n] 'ThinSVD = '( '[b, m, Min m n], '[b, Min m n], '[b, n, Min m n])
  SVDShapes '[b, m, n] 'FullSVD = '( '[b, m, m], '[b, Min m n], '[b, n, n])
  SVDShapes _ _ = TypeError (Text "A singular value decomposition can only be computed for 2D matrices with at most one batch dimension.")

data ReducedSVD = ThinSVD | FullSVD

class KnownReducedSVD (reduced :: ReducedSVD) where
  reducedSVD :: Bool

instance KnownReducedSVD ThinSVD where
  reducedSVD = True

instance KnownReducedSVD FullSVD where
  reducedSVD = False

type family SVDDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SVDDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  SVDDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  SVDDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | Singular Value Decomposition
-- TODO: When `compute_uv` is `False`, backward cannot be performed since `u` and `v` from the forward pass are required for the backward operation. There is no way to encode this in the types at this point in time. Thus, only `True` is supported currently.
--
-- This function returns a tuple `(u, s, v)`,
-- which is the singular value decomposition of the input real matrix
-- or batch of real matrices, such that
-- `input = U×diag(S)×V^T`.
--
-- >>> a <- randn :: IO (CPUTensor 'D.Float '[3, 5])
-- >>> (u, s, v) = svd @'ThinSVD a
-- >>> dtype &&& shape $ u
-- (Float,[3,3])
-- >>> dtype &&& shape $ s
-- (Float,[3])
-- >>> dtype &&& shape $ v
-- (Float,[5,3])
-- >>> (u, s, v) = svd @'FullSVD a
-- >>> dtype &&& shape $ u
-- (Float,[3,3])
-- >>> dtype &&& shape $ s
-- (Float,[3])
-- >>> dtype &&& shape $ v
-- (Float,[5,5])
-- >>> a <- randn :: IO (CPUTensor 'D.Float '[5, 3])
-- >>> (u, s, v) = svd @'ThinSVD a
-- >>> dtype &&& shape $ u
-- (Float,[5,3])
-- >>> dtype &&& shape $ s
-- (Float,[3])
-- >>> dtype &&& shape $ v
-- (Float,[3,3])
-- >>> (u, s, v) = svd @'FullSVD a
-- >>> dtype &&& shape $ u
-- (Float,[5,5])
-- >>> dtype &&& shape $ s
-- (Float,[3])
-- >>> dtype &&& shape $ v
-- (Float,[3,3])
svd ::
  forall reduced shape shapeU shapeS shapeV dtype device.
  ( KnownReducedSVD reduced,
    '(shapeU, shapeS, shapeV) ~ SVDShapes shape reduced,
    SVDDTypeIsValid device dtype
  ) =>
  -- | (batched) input real matrix
  Tensor device dtype shape ->
  -- | (batched) output tuple of `u`, `s`, and `v`
  ( Tensor device dtype shapeU,
    Tensor device dtype shapeS,
    Tensor device dtype shapeV
  )
svd input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.svd_tbb input (reducedSVD @reduced) True

type family CholeskyDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  CholeskyDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  CholeskyDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  CholeskyDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | cholesky
-- TODO: cholesky can throw if the input is not positive-definite.
-- Computes the Cholesky decomposition of a symmetric positive-definite matrix.
-- The operation supports batching.
--
-- Warning:
-- torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be removed in a future PyTorch release.
-- L = torch.cholesky(A)
-- should be replaced with
-- L = torch.linalg.cholesky(A)
-- and
-- U = torch.cholesky(A, upper=True)
-- should be replaced with
-- U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj() (function operator())
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[2,2])
-- >>> u = cholesky Upper (t `matmul` transpose2D t) -- Skip warning
-- ...
-- >>> dtype &&& shape $ u
-- (Float,[2,2])
-- >>> :t u
-- u :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2]
cholesky ::
  forall shape shape' dtype device.
  ( shape' ~ Square shape,
    CholeskyDTypeIsValid device dtype
  ) =>
  -- | indicates whether to return an upper or lower triangular matrix.
  Tri ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
cholesky upper input = unsafePerformIO $ ATen.cast2 ATen.Managed.cholesky_tb input boolUpper
  where
    boolUpper = isUpper upper

-- | choleskyInverse
-- Computes the inverse of a symmetric positive-definite matrix
-- using its Cholesky factor, returned, e.g., by `cholesky`.
-- Unlike `cholesky`, this operation does not support batching.
-- The inverse is computed using the LAPACK routine `?potri`.
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[2,2])
-- >>> tri = Upper
-- >>> u = cholesky tri (t `matmul` transpose2D t)
-- >>> dtype &&& shape $ choleskyInverse tri u
-- (Float,[2,2])
choleskyInverse ::
  forall n dtype device.
  ( 1 <= n,
    CholeskyDTypeIsValid device dtype
  ) =>
  -- | decides whether the upper or the lower triangular part of the input tensor is used
  Tri ->
  -- | the input 2-D tensor `u`, an upper or lower triangular Cholesky factor
  Tensor device dtype '[n, n] ->
  -- | the output 2-D tensor
  Tensor device dtype '[n, n]
choleskyInverse upper input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.cholesky_inverse_tb input boolUpper
  where
    boolUpper = isUpper upper

-- | choleskySolve
-- Solves the system of linear equations represented by `a c = b`
-- using the Cholesky factor matrix `u` of `a` (returned, e.g., by `cholesky`),
-- where `a` is a positive semidefinite matrix.
-- The operation supports batching.
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[3,3])
-- >>> a = t `matmul` transpose2D t
-- >>> b <- rand :: IO (CPUTensor 'D.Float '[3,2])
-- >>> tri = Upper
-- >>> u = cholesky tri a
-- >>> dtype &&& shape $ choleskySolve tri b u
-- (Float,[3,2])
choleskySolve ::
  forall m_k m_m dtype device.
  ( Square m_m ~ m_m,
    FstSquareDim m_m ~ FstSquareDim m_k,
    1 <= FstSquareDim m_m,
    CholeskyDTypeIsValid device dtype
  ) =>
  -- | decides whether the upper or the lower triangular part of the input tensor `u` is used
  Tri ->
  -- | the (batched) RHS tensor `b`
  Tensor device dtype m_k ->
  -- | the (batched) input 2-D tensor `u`, an upper or lower triangular Cholesky factor
  Tensor device dtype m_m ->
  -- | the (batched) output 2-D tensor
  Tensor device dtype m_k
choleskySolve upper b u =
  unsafePerformIO $ ATen.cast3 ATen.Managed.cholesky_solve_ttb b u boolUpper
  where
    boolUpper = isUpper upper

type family SolveDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SolveDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  SolveDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  SolveDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | solve
-- Solves the system of linear equations represented by `a c = b` and also returns the LU decomposition of `a`.
-- `a` has to be a square, non-singular matrix.
-- The operation supports batching.
--
-- Warning:
-- torch.solve is deprecated in favor of torch.linalg.solve and will be removed in a future PyTorch release.
-- torch.linalg.solve has its arguments reversed and does not return the LU factorization.
-- To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
-- X = torch.solve(B, A).solution
-- should be replaced with
-- X = torch.linalg.solve(A, B) (function operator())
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[10,10])
-- >>> a = t `matmul` transpose2D t
-- >>> b <- rand :: IO (CPUTensor 'D.Float '[10,3])
-- >>> (c,lu) = solve b a
-- >>> dtype &&& shape $ c -- Skip warning
-- ...
-- >>> dtype &&& shape $ c
-- (Float,[10,3])
-- >>> dtype &&& shape $ lu
-- (Float,[10,10])
-- >>> :t c
-- c :: Tensor '( 'D.CPU, 0) 'D.Float '[10, 3]
-- >>> :t lu
-- lu :: Tensor '( 'D.CPU, 0) 'D.Float '[10, 10]
solve ::
  forall m_k m_m dtype device.
  ( Square m_m ~ m_m,
    FstSquareDim m_m ~ FstSquareDim m_k,
    1 <= FstSquareDim m_m,
    SolveDTypeIsValid device dtype
  ) =>
  -- | the (batched) RHS tensor `b`
  Tensor device dtype m_k ->
  -- | the (batched) positive semidefinite matrix `a`
  Tensor device dtype m_m ->
  -- | the (batched) outputs c and lu
  ( Tensor device dtype m_k,
    Tensor device dtype m_m
  )
solve b a = unsafePerformIO $ ATen.cast2 ATen.Managed.solve_tt b a

-- | geqrf
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- `geqrf` computes a QR decomposition of the given `input` matrix,
-- but without constructing `Q` and `R` as explicit separate matrices.
-- Rather, this function directly calls the underlying LAPACK function `?geqrf`
-- which produces a tuple `(a, tau)` of intermediate results as defined in
-- the LAPACK documentation for `?geqrf`.
--
-- You can use `orgqr` on `(a, tau)` to compute the real orthogonal matrix `Q`,
-- but in general you may just want to use `qr` instead.
--
-- See the LAPACK documentation for `?geqrf` for further details,
-- https://software.intel.com/en-us/node/521004.
--
-- >>> (a, tau) = geqrf (ones :: CPUTensor 'D.Float '[3,4])
-- >>> dtype &&& shape $ a
-- (Float,[3,4])
-- >>> dtype &&& shape $ tau
-- (Float,[3])
-- >>> (a, tau) = geqrf (ones :: CPUTensor 'D.Float '[4,3])
-- >>> dtype &&& shape $ a
-- (Float,[4,3])
-- >>> dtype &&& shape $ tau
-- (Float,[3])
geqrf ::
  forall m n dtype device.
  -- | input matrix
  Tensor device dtype '[m, n] ->
  -- | tuple `(a, tau)` of output matrices
  ( Tensor device dtype '[m, n],
    Tensor device dtype '[Min m n]
  )
geqrf input = unsafePerformIO $ ATen.cast1 ATen.Managed.geqrf_t input

-- | orgqr
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- Computes the orthogonal matrix `Q` of a QR factorization
-- from the `(a, tau)` tuple returned by `geqrf`.
--
-- This directly calls the underlying LAPACK function `?orgqr`.
-- See the LAPACK documentation for `?orgqr` for further details,
-- https://software.intel.com/en-us/mkl-developer-reference-c-orgqr.
--
-- As of libtorch 1.7, the behavior of this function changed:
-- the first dimension must not be smaller than the second dimension.
--
-- >>> dtype &&& shape $ orgqr (ones :: CPUTensor 'D.Float '[4,3]) (ones :: CPUTensor 'D.Float '[3])
-- (Float,[4,3])
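--
-- A sketch of composing 'geqrf' and 'orgqr' (assuming a tall input, @m >= n@,
-- so that the @Min m n@ in the result type of 'geqrf' reduces to @n@):
--
-- @
-- let (a, tau) = geqrf (ones :: CPUTensor 'D.Float '[4, 3])
--     q = orgqr a tau -- q :: Tensor '( 'D.CPU, 0) 'D.Float '[4, 3]
-- @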
orgqr ::
  forall m n dtype device.
  ( KnownNat n,
    KnownNat m,
    n <= m
  ) =>
  Tensor device dtype '[m, n] ->
  Tensor device dtype '[n] ->
  Tensor device dtype '[m, n]
orgqr a tau = unsafePerformIO $ ATen.cast2 ATen.Managed.orgqr_tt a tau

-- | sign
-- works for all dtypes
--
-- >>> dtype &&& shape $ sign (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
sign ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
sign input = unsafePerformIO $ ATen.cast1 ATen.Managed.sign_t input

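-- | Replace the element at index @i@ of a type-level list with @j@.
--
-- >>> :kind! SetValue '[3,2,1] 1 5
-- SetValue '[3,2,1] 1 5 :: [Natural]
-- = '[3, 5, 1]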
type family SetValue (shape :: [Nat]) (i :: Nat) (j :: Nat) :: [Nat] where
  SetValue '[] _ _ = '[]
  SetValue (x : xs) 0 j = j : xs
  SetValue (x : xs) i j = x : SetValue xs (i -1) j

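-- | Look up the element at index @i@ of a type-level list.
--
-- >>> :kind! GetValue '[3,2,1] 1
-- GetValue '[3,2,1] 1 :: Natural
-- = 2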
type family GetValue (shape :: [Nat]) (i :: Nat) :: Nat where
  GetValue '[] _ = TypeError (Text "Cannot find an element in the list.")
  GetValue (x : xs) 0 = x
  GetValue (x : xs) i = GetValue xs (i -1)

-- | Transpose
--
-- >>> :kind! Transpose '[3,2] 0 1
-- Transpose '[3,2] 0 1 :: [Natural]
-- = '[2, 3]
-- >>> :kind! Transpose '[3,2,1] 1 2
-- Transpose '[3,2,1] 1 2 :: [Natural]
-- = '[3, 1, 2]
type family Transpose (shape :: [Nat]) (dim0 :: Nat) (dim1 :: Nat) :: [Nat] where
  Transpose s d0 d1 = (SetValue (SetValue s d0 (GetValue s d1)) d1 (GetValue s d0))

-- | transpose
-- See "../../../../deps/pytorch/aten/src/ATen/native/TensorShape.cpp".
--
-- >>> dtype &&& shape $ transpose @0 @1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[2,3])
-- >>> dtype &&& shape $ transpose @0 @1 (ones :: CPUTensor 'D.Float '[3,2,1])
-- (Float,[2,3,1])
-- >>> dtype &&& shape $ transpose @1 @2 (ones :: CPUTensor 'D.Float '[3,2,1])
-- (Float,[3,1,2])
transpose ::
  forall n m shape shape' dtype device.
  ( KnownNat n,
    KnownNat m,
    shape' ~ Transpose shape n m
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
transpose input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.transpose_tll input (natValI @n) (natValI @m)

-- | transpose2d, special case for a 2D tensor
--
-- >>> dtype &&& shape $ transpose2D (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[2,3])
transpose2D ::
  forall (i :: Nat) (j :: Nat) dtype device.
  -- | input
  Tensor device dtype '[i, j] ->
  -- | output
  Tensor device dtype '[j, i]
transpose2D = transpose @0 @1

class KnownTri (tri :: Tri) where
  triVal :: Tri

instance KnownTri Upper where
  triVal = Upper

instance KnownTri Lower where
  triVal = Lower

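-- | Length of the @index@-th diagonal of an @m x n@ matrix, counted upwards
-- ('Upper') or downwards ('Lower') from the main diagonal; an out-of-range
-- index is rejected with a type error. For instance, the first upper diagonal
-- of a @3 x 2@ matrix has a single element, i.e. @DiagSize 'Upper 1 3 2 ~ 1@.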
type family DiagSize (tri :: Tri) (index :: Nat) (m :: Nat) (n :: Nat) :: Nat where
  DiagSize 'Upper i m n =
    If
      (i <=? n)
      (Min m (n - i))
      ( TypeError
          ( Text "For a matrix with shape "
              :<>: ShowType '[m, n]
              :<>: Text ", the maximum index for an upper diagonal is "
              :<>: ShowType n
              :<>: Text ", but asked for index "
              :<>: ShowType i
          )
      )
  DiagSize 'Lower i m n =
    If
      (i <=? m)
      (Min (m - i) n)
      ( TypeError
          ( Text "For a matrix with shape "
              :<>: ShowType '[m, n]
              :<>: Text ", the maximum index for a lower diagonal is "
              :<>: ShowType m
              :<>: Text ", but asked for index "
              :<>: ShowType i
          )
      )

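-- | Shape of the result of 'diag': a vector input of length @n@ produces an
-- @(n + index) x (n + index)@ matrix with the vector on the chosen diagonal,
-- while a matrix input produces a vector holding the chosen diagonal.
--
-- >>> :kind! DiagShape 'Upper 1 '[3]
-- DiagShape 'Upper 1 '[3] :: [Natural]
-- = '[4, 4]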
type family DiagShape (tri :: Tri) (index :: Nat) (shape :: [Nat]) :: [Nat] where
  DiagShape _ i '[n] = '[n + i, n + i]
  DiagShape tri i '[m, n] = '[DiagSize tri i m n]
  DiagShape _ _ shape =
    TypeError
      ( Text "The input must be a matrix or a vector, but it has "
          :<>: ShowType (ListLength shape)
          :<>: Text " dimensions."
      )

-- | diag
--
-- >>> dtype &&& shape $ diag @'Upper @0 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[2])
-- >>> dtype &&& shape $ diag @'Upper @1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[1])
-- >>> dtype &&& shape $ diag @'Lower @1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[2])
diag ::
  forall tri index shape shape' device dtype.
  ( KnownTri tri,
    KnownNat index,
    StandardDTypeValidation device dtype,
    shape' ~ DiagShape tri index shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
diag t =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.tensor_diag_l t $
      case triVal @tri of
        Upper -> natValI @index
        Lower -> - natValI @index

-- | all
-- See https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.all.
--
-- >>> t = all (fromJust [False, False] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- False
--
-- >>> t = all (fromJust [False, True] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- False
--
-- >>> t = all (fromJust [True, True] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- True
all ::
  forall shape device.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool '[]
all input = unsafePerformIO $ ATen.cast1 ATen.Managed.all_t input

-- | any
-- See https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.any.
--
-- >>> t = any (fromJust [False, False] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- False
--
-- >>> t = any (fromJust [False, True] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- True
--
-- >>> t = any (fromJust [True, True] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- True
any ::
  forall shape device.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool '[]
any input = unsafePerformIO $ ATen.cast1 ATen.Managed.any_t input

data KeepOrDropDim = KeepDim | DropDim

class KnownKeepOrDropDim keepOrDropDim where
  keepOrDropDimVal :: Bool

instance KnownKeepOrDropDim KeepDim where
  keepOrDropDimVal = True

instance KnownKeepOrDropDim DropDim where
  keepOrDropDimVal = False

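-- | Compute the shape that results from reducing over dimension @dim@:
-- with 'KeepDim' the dimension is retained with size 1, with 'DropDim' it is removed.
--
-- >>> :kind! ConditionalDropDimension '[4,2] 1 'DropDim
-- ConditionalDropDimension '[4,2] 1 'DropDim :: [Natural]
-- = '[4]
-- >>> :kind! ConditionalDropDimension '[4,2] 1 'KeepDim
-- ConditionalDropDimension '[4,2] 1 'KeepDim :: [Natural]
-- = '[4, 1]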
type family ConditionalDropDimension (shape :: [Nat]) (dim :: Nat) (keepOrDropDim :: KeepOrDropDim) :: [Nat] where
  ConditionalDropDimension '[] _ _ = TypeError (Text "The specified dimension is not available.")
  ConditionalDropDimension (x : xs) 0 KeepDim = 1 ': xs
  ConditionalDropDimension (x : xs) 0 DropDim = xs
  ConditionalDropDimension (x : xs) i keepOrDropDim = x ': ConditionalDropDimension xs (i - 1) keepOrDropDim

-- | allDim
-- See https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.all.
--
-- >>> t = fromJust [[True, True], [True, False], [True, True], [True, True]] :: CPUTensor 'D.Bool '[4, 2]
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Bool]) $ allDim @1 @DropDim t
-- (Bool,([4],[True,False,True,True]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Bool]]) $ allDim @1 @KeepDim t
-- (Bool,([4,1],[[True],[False],[True],[True]]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Bool]) $ allDim @0 @DropDim t
-- (Bool,([2],[True,False]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Bool]]) $ allDim @0 @KeepDim t
-- (Bool,([1,2],[[True,False]]))
allDim ::
  forall dim keepOrDropDim shape' shape device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape'
allDim input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.all_tlb input (natValI @dim) (keepOrDropDimVal @keepOrDropDim)

-- | anyDim
-- See https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.any.
--
-- >>> t = fromJust [[True, True], [True, False], [True, True], [True, True]] :: CPUTensor 'D.Bool '[4, 2]
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Bool]) $ anyDim @1 @DropDim t
-- (Bool,([4],[True,True,True,True]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Bool]]) $ anyDim @1 @KeepDim t
-- (Bool,([4,1],[[True],[True],[True],[True]]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Bool]) $ anyDim @0 @DropDim t
-- (Bool,([2],[True,True]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Bool]]) $ anyDim @0 @KeepDim t
-- (Bool,([1,2],[[True,True]]))
anyDim ::
  forall dim keepOrDropDim shape' shape device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape'
anyDim input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.any_tlb input (natValI @dim) (keepOrDropDimVal @keepOrDropDim)

-- | dropout
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: get rid of IO by exposing the RNG state
-- TODO: can we use D.Scalar for the dropout probability?
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,2]
-- >>> t' <- dropout 0.5 False t
-- >>> dtype &&& shape $ t'
-- (Float,[3,2])
-- >>> t'' <- dropout 0.5 False t
-- >>> t ==. t''
-- Tensor Bool [3,2] [[ 1,  1],
--                    [ 1,  1],
--                    [ 1,  1]]
-- >>> t''' <- dropout 0.0 True t
-- >>> t ==. t'''
-- Tensor Bool [3,2] [[ 1,  1],
--                    [ 1,  1],
--                    [ 1,  1]]
-- >>> t'''' <- dropout 1.0 True t
-- >>> t''''
-- Tensor Float [3,2] [[ 0.0000,  0.0000],
--                     [ 0.0000,  0.0000],
--                     [ 0.0000,  0.0000]]
dropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  IO (Tensor device dtype shape)
dropout p train input = ATen.cast3 ATen.Managed.dropout_tdb input p train

-- | featureDropout
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: why not IO?
-- TODO: can we use D.Scalar for the dropout probability?
--
-- >>> c = featureDropout 0.1 True (ones :: CPUTensor 'D.Float '[2,2])
-- >>> dtype &&& shape $ c
-- (Float,[2,2])
featureDropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
featureDropout p train input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.feature_dropout_tdb input p train

-- | alphaDropout
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: why not IO?
-- TODO: can we use D.Scalar for the dropout probability?
--
-- >>> c = alphaDropout 0.1 True (ones :: CPUTensor 'D.Float '[2,2])
-- >>> dtype &&& shape $ c
-- (Float,[2,2])
alphaDropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
alphaDropout p train input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.alpha_dropout_tdb input p train

-- | featureAlphaDropout
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: why not IO?
-- TODO: can we use D.Scalar for the dropout probability?
--
-- >>> c = featureAlphaDropout 0.1 True (ones :: CPUTensor 'D.Float '[2,2])
-- >>> dtype &&& shape $ c
-- (Float,[2,2])
featureAlphaDropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
featureAlphaDropout p train input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.feature_alpha_dropout_tdb input p train

-- | acos
--
-- >>> dtype &&& shape $ acos (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
acos ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
acos input = unsafePerformIO $ ATen.cast1 ATen.Managed.acos_t input

-- | avgPool1d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = avgPool1d @1 @1 @0 (ones :: CPUTensor 'D.Float '[1,3,4])
-- >>> shape t
-- [1,3,4]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 4]
avgPool1d ::
  forall
    kernelSize
    stride
    padding
    channelSize
    inputSize
    batchSize
    outputSize
    dtype
    device.
  ( All KnownNat '[kernelSize, stride, padding, channelSize, inputSize, batchSize],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize]
avgPool1d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.avg_pool1d_tlllbb
      input
      (natValI @kernelSize)
      (natValI @stride)
      (natValI @padding)
      False
      True

-- | adaptiveAvgPool1d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = adaptiveAvgPool1d @8 (ones :: CPUTensor 'D.Float '[1,3,16])
-- >>> shape t
-- [1,3,8]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 8]
adaptiveAvgPool1d ::
  forall outputSize channelSize inputSize batchSize dtype device.
  (All KnownNat '[channelSize, inputSize, batchSize, outputSize]) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize]
adaptiveAvgPool1d input =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.adaptive_avg_pool1d_tl input (natValI @outputSize)

-- | adaptiveMaxPool1d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> tt = adaptiveMaxPool1d @8 (ones :: CPUTensor 'D.Float '[1,3,16])
-- >>> shape . fst $ tt
-- [1,3,8]
-- >>> :t tt
-- tt
--   :: (Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 8],
--       Tensor '( 'D.CPU, 0) 'D.Int64 '[1, 3, 8])
adaptiveMaxPool1d ::
  forall outputSize channelSize inputSize batchSize dtype device.
  (All KnownNat '[channelSize, inputSize, batchSize, outputSize]) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  ( Tensor device dtype '[batchSize, channelSize, outputSize],
    Tensor device 'D.Int64 '[batchSize, channelSize, outputSize]
  )
adaptiveMaxPool1d input =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.adaptive_max_pool1d_tl input (natValI @outputSize)

-- | addmv
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: can we use D.Scalar for beta and alpha?
--
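-- Informally, @addmv beta alpha mat vec input@ computes
-- @beta * input + alpha * (mat * vec)@, where @mat * vec@ denotes the
-- matrix-vector product and @input@ is broadcast to the result shape.
--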
-- >>> t = addmv 1 1 (ones :: CPUTensor 'D.Float '[3,2]) (zeros :: CPUTensor 'D.Float '[2]) (ones :: CPUTensor 'D.Float '[])
-- >>> dtype &&& shape $ t
-- (Float,[3])
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[3]
addmv ::
  forall shape' shape n m dtype device.
  ( KnownNat n,
    KnownNat m,
    shape' ~ Broadcast shape '[n]
  ) =>
  -- | beta
  Float ->
  -- | alpha
  Float ->
  -- | matrix
  Tensor device dtype '[n, m] ->
  -- | vector
  Tensor device dtype '[m] ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
addmv beta alpha mat vec input =
  unsafePerformIO $ ATen.cast5 ATen.Managed.addmv_tttss input mat vec beta alpha

-- affine_grid_generator :: Tensor device dtype shape -> [Int] -> Tensor device dtype shape
-- affine_grid_generator _theta _size = unsafePerformIO $ (ATen.cast2 ATen.Managed.affine_grid_generator_tl) _theta _size

-- | allclose
--
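-- Two tensors are considered close if
-- @|input - other| <= atol + rtol * |other|@ holds elementwise.
--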
-- >>> allclose 0.1 0.1 True (ones :: CPUTensor 'D.Float '[3,3]) (ones :: CPUTensor 'D.Float '[3,3])
-- True
allclose ::
  forall shape dtype device.
  -- | relative tolerance
  Double ->
  -- | absolute tolerance
  Double ->
  -- | whether or not NaN equals NaN
  Bool ->
  -- | input tensor
  Tensor device dtype shape ->
  -- | other input tensor
  Tensor device dtype shape ->
  -- | output
  Bool
allclose rtol atol equalNaN input other =
  unsafePerformIO $ ATen.cast5 ATen.Managed.allclose_ttddb input other rtol atol equalNaN

-- | argmax
-- See https://pytorch.org/docs/stable/torch.html#torch.argmax.
--
-- >>> t = fromJust [[0, 1], [-1, 2], [0, 1], [0, -2]] :: CPUTensor 'D.Float '[4, 2]
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Int]) $ argmax @1 @DropDim t
-- (Int64,([4],[1,1,1,0]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Int]]) $ argmax @1 @KeepDim t
-- (Int64,([4,1],[[1],[1],[1],[0]]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Int]) $ argmax @0 @DropDim t
-- (Int64,([2],[0,1]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Int]]) $ argmax @0 @KeepDim t
-- (Int64,([1,2],[[0,1]]))
argmax ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device 'D.Int64 shape'
argmax input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.argmax_tlb input (natValI @dim) (keepOrDropDimVal @keepOrDropDim)

-- | argmin
-- See https://pytorch.org/docs/stable/torch.html#torch.argmin.
--
-- >>> t = fromJust [[0, 1], [-1, 2], [0, 1], [0, -2]] :: CPUTensor 'D.Float '[4, 2]
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Int]) $ argmin @1 @DropDim t
-- (Int64,([4],[0,0,0,1]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Int]]) $ argmin @1 @KeepDim t
-- (Int64,([4,1],[[0],[0],[0],[1]]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Int]) $ argmin @0 @DropDim t
-- (Int64,([2],[1,3]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Int]]) $ argmin @0 @KeepDim t
-- (Int64,([1,2],[[1,3]]))
argmin ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device 'D.Int64 shape'
argmin input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.argmin_tlb input (natValI @dim) (keepOrDropDimVal @keepOrDropDim)

-- as_strided :: Tensor device dtype shape -> [Int] -> [Int] -> Int -> Tensor device dtype shape
-- as_strided _input _size _stride _storage_offset = unsafePerformIO $ (ATen.cast4 ATen.Managed.as_strided_tlll) _input _size _stride _storage_offset

-- | asin
--
-- >>> dtype &&& shape $ asin (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
asin ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
asin input = unsafePerformIO $ ATen.cast1 ATen.Managed.asin_t input

-- | atan
--
-- >>> dtype &&& shape $ atan (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
atan ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
atan input = unsafePerformIO $ ATen.cast1 ATen.Managed.atan_t input

-- | baddbmm
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
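-- Informally, @baddbmm beta alpha batch1 batch2 input@ computes
-- @beta * input + alpha * (batch1 * batch2)@, where @batch1 * batch2@ denotes
-- the batched matrix product and @input@ is broadcast to the result shape.
--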
-- >>> t = baddbmm 1 1 (ones :: CPUTensor 'D.Float '[5,3,2]) (zeros :: CPUTensor 'D.Float '[5,2,4]) (ones :: CPUTensor 'D.Float '[])
-- >>> dtype &&& shape $ t
-- (Float,[5,3,4])
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 3, 4]
baddbmm ::
  forall shape' shape batchSize n m k dtype device.
  ( KnownNat n,
    KnownNat m,
    KnownNat k,
    shape' ~ Broadcast shape '[batchSize, n, m]
  ) =>
  -- | beta
  Float ->
  -- | alpha
  Float ->
  -- | first batch
  Tensor device dtype '[batchSize, n, k] ->
  -- | second batch
  Tensor device dtype '[batchSize, k, m] ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
baddbmm beta alpha batch1 batch2 input =
  unsafePerformIO $ ATen.cast5 ATen.Managed.baddbmm_tttss input batch1 batch2 beta alpha

-- batch_norm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Double -> Double -> Bool -> Tensor device dtype shape
-- batch_norm _input _weight _bias _running_mean _running_var _training _momentum _eps _cudnn_enabled = unsafePerformIO $ (ATen.cast9 ATen.Managed.batch_norm_tttttbddb) _input _weight _bias _running_mean _running_var _training _momentum _eps _cudnn_enabled

-- bilinear :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- bilinear _input1 _input2 _weight _bias = unsafePerformIO $ (ATen.cast4 ATen.Managed.bilinear_tttt) _input1 _input2 _weight _bias

-- binary_cross_entropy_with_logits :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Int -> Tensor device dtype shape
-- binary_cross_entropy_with_logits _input _target _weight _pos_weight _reduction = unsafePerformIO $ (ATen.cast5 ATen.Managed.binary_cross_entropy_with_logits_ttttl) _input _target _weight _pos_weight _reduction

-- bincount :: Tensor device dtype shape -> Tensor device dtype shape -> Int -> Tensor device dtype shape
-- bincount _input _weights _minlength = unsafePerformIO $ (ATen.cast3 ATen.Managed.bincount_ttl) _input _weights _minlength

-- | batched matrix multiplication
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ bmm (ones :: CPUTensor 'D.Float '[5,3,2]) (zeros :: CPUTensor 'D.Float '[5,2,4])
-- (Float,[5,3,4])
bmm ::
  forall batchSize n m k dtype device.
  -- | input
  Tensor device dtype '[batchSize, n, k] ->
  -- | other input
  Tensor device dtype '[batchSize, k, m] ->
  -- | output
  Tensor device dtype '[batchSize, n, m]
bmm input other = unsafePerformIO $ ATen.cast2 ATen.Managed.bmm_tt input other

-- | BroadcastTensorsImpl
--
-- >>> type Ty = BroadcastTensorsImpl '[] 'Nothing
-- >>> :kind! Ty
-- Ty :: Maybe ([Natural], D.DType, (D.DeviceType, Natural))
-- = 'Nothing
-- >>> type Ty = BroadcastTensorsImpl '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1]] 'Nothing
-- >>> :kind! Ty
-- Ty :: Maybe ([Natural], D.DType, (D.DeviceType, Natural))
-- = 'Just '( '[2, 3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = BroadcastTensorsImpl '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1], Tensor '( 'D.CPU, 0) 'D.Float '[5, 1, 1]] 'Nothing
-- >>> :kind! Ty
-- Ty :: Maybe ([Natural], D.DType, (D.DeviceType, Natural))
-- = 'Just '( '[5, 2, 3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = BroadcastTensorsImpl '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1], Tensor '( 'D.CPU, 0) 'D.Float '[1, 5, 1]] 'Nothing
-- >>> :kind! Ty
-- Ty :: Maybe ([Natural], D.DType, (D.DeviceType, Natural))
-- = 'Nothing
type family BroadcastTensorsImpl (tensors :: [a]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  BroadcastTensorsImpl '[] 'Nothing = 'Nothing
  BroadcastTensorsImpl '[] ('Just '(reverseShape, dtype, device)) = 'Just '(Reverse reverseShape, dtype, device)
  BroadcastTensorsImpl (Tensor device dtype shape ': tensors) 'Nothing = BroadcastTensorsImpl tensors ('Just '(Reverse shape, dtype, device))
  BroadcastTensorsImpl (Tensor device dtype shape ': tensors) ('Just '(reverseShape', dtype, device)) = BroadcastTensorsImpl tensors (MaybeTriple (ComputeBroadcast (Reverse shape) reverseShape') ('Just dtype) ('Just device))
  BroadcastTensorsImpl (Tensor device dtype shape ': _) ('Just _) = Nothing

type family BroadcastTensorsCheck (tensors :: [a]) (result :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: [a] where
  BroadcastTensorsCheck tensors 'Nothing =
    TypeError
      ( Text "Cannot broadcast tensors due to incompatible shapes and/or dtypes: "
          :<>: ShowType tensors
      )
  BroadcastTensorsCheck tensors ('Just '(shape, dtype, device)) = HReplicateR (ListLength tensors) (Tensor device dtype shape)

type BroadcastTensors tensors =
  BroadcastTensorsCheck tensors (BroadcastTensorsImpl tensors 'Nothing)
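
-- A kind-level sketch of 'BroadcastTensors' (not verified by doctest); both result
-- elements reduce to the common broadcast type:
--
-- > :kind! BroadcastTensors '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1]]
-- > -- = '[Tensor '( 'D.CPU, 0) 'D.Float '[2, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 3]]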

-- | broadcast tensors
-- TODO: broadcastTensors returns garbage data and is hence broken
-- See https://pytorch.org/docs/stable/_modules/torch/functional.html#broadcast_tensors.
--
-- >>> x = ones :: CPUTensor 'D.Float '[1, 3]
-- >>> y = ones :: CPUTensor 'D.Float '[2, 1]
-- >>> z = ones :: CPUTensor 'D.Float '[5, 1, 1]
--
-- -- >>> x' :. y' :. z' :. HNil = broadcastTensors (x :. y :. z :. HNil)
-- -- >>> :type x'
-- -- x' :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 2, 3]
-- -- >>> dtype &&& shape &&& (\t -> D.asValue (toDynamic t) :: [[[Float]]]) $ x'
-- -- >>> :type y'
-- -- y' :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 2, 3]
-- -- >>> dtype &&& shape &&& (\t -> D.asValue (toDynamic t) :: [[[Float]]]) $ y'
-- -- >>> :type z'
-- -- z' :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 2, 3]
-- -- >>> dtype &&& shape &&& (\t -> D.asValue (toDynamic t) :: [[[Float]]]) $ z'
broadcastTensors ::
  forall tensors tensors'.
  ( tensors' ~ BroadcastTensors tensors,
    ATen.Castable (HList tensors) [D.ATenTensor],
    ATen.Castable (HList tensors') [D.ATenTensor]
  ) =>
  -- | input list of tensors
  HList tensors ->
  -- | output list of tensors
  HList tensors'
broadcastTensors tensors = unsafePerformIO $ ATen.cast1 ATen.Managed.broadcast_tensors_l tensors

type family CatImpl (dim :: Nat) (tensors :: [a]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  CatImpl _ '[] acc = acc
  CatImpl dim (Tensor device dtype shape ': tensors) acc = CatImpl dim tensors (MaybeTriple (ComputeCatShape dim shape acc) (ComputeCatDType dtype acc) (ComputeCatDevice device acc))

type family ComputeCatShape (dim :: Nat) (shape :: [Nat]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe [Nat] where
  ComputeCatShape 0 (x ': xs) Nothing = Just (x ': xs)
  ComputeCatShape dim (x ': xs) Nothing = AppendToMaybe x (ComputeCatShape (dim - 1) xs Nothing)
  ComputeCatShape 0 (x ': xs) (Just '(y ': xs, _, _)) = Just ((x + y) ': xs)
  ComputeCatShape dim (x ': xs) (Just '(x ': ys, dtype, device)) = AppendToMaybe x (ComputeCatShape (dim - 1) xs (Just '(ys, dtype, device)))
  ComputeCatShape _ _ _ = Nothing

type family ComputeCatDType (dtype :: D.DType) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe D.DType where
  ComputeCatDType dtype Nothing = Just dtype
  ComputeCatDType dtype (Just '(_, dtype, _)) = Just dtype
  ComputeCatDType _ _ = Nothing

type family ComputeCatDevice (device :: (D.DeviceType, Nat)) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe (D.DeviceType, Nat) where
  ComputeCatDevice device Nothing = Just device
  ComputeCatDevice device (Just '(_, _, device)) = Just device
  ComputeCatDevice _ _ = Nothing

type family CatCheck (res :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: ([Nat], D.DType, (D.DeviceType, Nat)) where
  CatCheck 'Nothing = TypeError (Text "Concatenation impossible.")
  CatCheck ('Just '(shape, dtype, device)) = '(shape, dtype, device)

-- | Cat
--
-- >>> type Ty = Cat 0 '[Tensor '( 'D.CPU, 0) 'D.Float '[1]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[1], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Cat 0 '[Tensor '( 'D.CPU, 0) 'D.Float '[1], Tensor '( 'D.CPU, 0) 'D.Float '[2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Cat 0 '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 3]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[3, 3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Cat 1 '[Tensor '( 'D.CPU, 0) 'D.Float '[3, 1], Tensor '( 'D.CPU, 0) 'D.Float '[3, 2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[3, 3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Cat 1 '[Tensor '( 'D.CPU, 0) 'D.Float '[2, 5, 4, 2], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1, 4, 2], Tensor '( 'D.CPU, 0) 'D.Float '[2, 3, 4, 2], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1, 4, 2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[2, 10, 4, 2], 'D.Float, '( 'D.CPU, 0))
type Cat dim tensors = CatCheck (CatImpl dim tensors Nothing)

-- | cat
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> t' = cat @0 (t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([2,2],[[1.0,1.0],[1.0,1.0]]))
-- >>> t' = cat @1 (t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([2,2],[[1.0,1.0],[1.0,1.0]]))
-- >>> t' = cat @0 (t :. t :. t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[6, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([6,2],[[1.0,1.0],[1.0,1.0],[1.0,1.0],[1.0,1.0],[1.0,1.0],[1.0,1.0]]))
-- >>> t' = cat @1 (t :. t :. t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 6]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([2,6],[[1.0,1.0,1.0,1.0,1.0,1.0],[1.0,1.0,1.0,1.0,1.0,1.0]]))
cat ::
  forall dim shape dtype device tensors.
  ( KnownNat dim,
    '(shape, dtype, device) ~ Cat dim tensors,
    ATen.Castable (HList tensors) [D.ATenTensor]
  ) =>
  -- | input list of tensors
  HList tensors ->
  -- | output tensor
  Tensor device dtype shape
cat tensors = unsafePerformIO $ ATen.cast2 ATen.Managed.cat_ll tensors (natValI @dim :: Int)
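
-- Concatenation also works for tensors whose sizes differ in the concatenation
-- dimension, as the kind-level 'Cat' examples above show. A sketch (not verified
-- by doctest):
--
-- > t'' = cat @0 ((ones :: CPUTensor 'D.Float '[1, 3]) :. (zeros :: CPUTensor 'D.Float '[2, 3]) :. HNil)
-- > -- t'' :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 3]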

-- chain_matmul :: [Tensor device dtype shape] -> Tensor device dtype shape
-- chain_matmul _matrices = unsafePerformIO $ (ATen.cast1 ATen.Managed.chain_matmul_l) _matrices

type family ChunkImpl (chunkShapes :: Maybe [[Nat]]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) :: Maybe a where
  ChunkImpl (Just '[]) _ _ = Just '[]
  ChunkImpl (Just (shape ': shapes)) dtype device = AppendToMaybe (Tensor device dtype shape) (ChunkImpl (Just shapes) dtype device)
  ChunkImpl Nothing _ _ = Nothing

type family ChunkCheck (shape :: [Nat]) (dim :: Nat) (result :: Maybe a) :: a where
  ChunkCheck shape dim Nothing = DimOutOfBound shape dim
  ChunkCheck _ _ (Just result) = result

type family ComputeChunksChunkGo (n' :: Nat) (r :: Nat) (cmp :: Ordering) (cmp' :: Ordering) :: [Nat] where
  ComputeChunksChunkGo n' r GT _ = n' ': ComputeChunksChunkGo n' (r - n') (CmpNat (r - n') n') (CmpNat (r - n') 0)
  ComputeChunksChunkGo n' r EQ _ = n' ': ComputeChunksChunkGo n' (r - n') (CmpNat (r - n') n') (CmpNat (r - n') 0)
  ComputeChunksChunkGo n' r _ GT = '[r]
  ComputeChunksChunkGo n' _ _ _ = '[]

type family ComputeChunksChunkGo0 (n' :: Nat) (chunks :: Nat) :: [Nat] where
  ComputeChunksChunkGo0 _ 0 = '[]
  ComputeChunksChunkGo0 n' chunks = n' ': (ComputeChunksChunkGo0 n' (chunks - 1))

type family ComputeChunks (n :: Maybe Nat) (chunks :: Nat) :: Maybe [Nat] where
  ComputeChunks (Just n) chunks = Just (ComputeChunks' n chunks (Mod n chunks))
  ComputeChunks Nothing _ = Nothing

type family ComputeChunks' (n :: Nat) (chunks :: Nat) (m :: Nat) :: [Nat] where
  ComputeChunks' n chunks 0 = ComputeChunksChunkGo0 (Div n chunks) chunks
  ComputeChunks' n chunks _ = ComputeChunksChunkGo (Div (n + chunks - 1) chunks) n (CmpNat n (Div (n + chunks - 1) chunks)) (CmpNat n 0)

type family ChunkShapesImpl (chunks :: Maybe [Nat]) (dim :: Nat) (shape :: [Nat]) :: Maybe [[Nat]] where
  ChunkShapesImpl (Just (n ': ns)) dim shape = AppendToMaybe' (ReplaceDim dim shape n) (ChunkShapesImpl (Just ns) dim shape)
  ChunkShapesImpl (Just '[]) _ _ = Just '[]
  ChunkShapesImpl Nothing _ _ = Nothing

type ChunkShapes chunks dim shape = ChunkShapesImpl (ComputeChunks (ExtractDim dim shape) chunks) dim shape

type Chunk chunks dim shape dtype device = ChunkCheck shape dim (ChunkImpl (ChunkShapes chunks dim shape) dtype device)

-- | chunk
--
-- -- >>> :type chunk @3 @1 (ones :: CPUTensor 'D.Float '[2, 2])
-- -- chunk @3 @1 (ones :: CPUTensor 'D.Float '[2, 2])
-- --   :: HList
-- --        '[Tensor '( 'D.CPU, 0) 'D.Float '[2, 1],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[2, 1]]
-- >>> t0 :. t1 :. HNil = chunk @3 @1 (ones :: CPUTensor 'D.Float '[2, 2])
-- >>> dtype &&& shape $ t0
-- (Float,[2,1])
-- >>> dtype &&& shape $ t1
-- (Float,[2,1])
--
-- -- >>> :type chunk @3 @1 (ones :: CPUTensor 'D.Float '[1, 0, 3])
-- -- chunk @3 @1 (ones :: CPUTensor 'D.Float '[1, 0, 3])
-- --   :: HList
-- --        '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 0, 3],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[1, 0, 3],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[1, 0, 3]]
-- >>> t0 :. t1 :. t2 :. HNil = chunk @3 @1 (ones :: CPUTensor 'D.Float '[1, 0, 3])
-- >>> dtype &&& shape $ t0
-- (Float,[1,0,3])
-- >>> dtype &&& shape $ t1
-- (Float,[1,0,3])
-- >>> dtype &&& shape $ t2
-- (Float,[1,0,3])
--
-- -- >>> :type chunk @6 @0 (ones :: CPUTensor 'D.Float '[19, 4])
-- -- chunk @6 @0 (ones :: CPUTensor 'D.Float '[19, 4])
-- --   :: HList
-- --        '[Tensor '( 'D.CPU, 0) 'D.Float '[4, 4],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[4, 4],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[4, 4],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[4, 4],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[3, 4]]
-- >>> t0 :. t1 :. t2 :. t3 :. t4 :. HNil = chunk @6 @0 (ones :: CPUTensor 'D.Float '[19, 4])
-- >>> dtype &&& shape $ t0
-- (Float,[4,4])
-- >>> dtype &&& shape $ t1
-- (Float,[4,4])
-- >>> dtype &&& shape $ t2
-- (Float,[4,4])
-- >>> dtype &&& shape $ t3
-- (Float,[4,4])
-- >>> dtype &&& shape $ t4
-- (Float,[3,4])
chunk ::
  forall chunks dim shape dtype device tensorChunks.
  ( KnownNat chunks,
    KnownNat dim,
    tensorChunks ~ Chunk chunks dim shape dtype device,
    ATen.Castable (HList tensorChunks) [D.ATenTensor]
  ) =>
  -- | input tensor
  Tensor device dtype shape ->
  -- | output list of tensors
  HList tensorChunks
chunk input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.chunk_tll input (natValI @chunks :: Int) (natValI @dim :: Int)

-- | clamp
-- TODO: probably only defined for floating-point tensors, or maybe the numeric type is lifted?
--
-- >>> dtype &&& shape $ clamp 0 1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
clamp ::
  forall shape dtype device a.
  (D.Scalar a) =>
  -- | minimum value
  a ->
  -- | maximum value
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
clamp min max input = unsafePerformIO $ ATen.cast3 ATen.Managed.clamp_tss input min max
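
-- A value-level sketch (not verified by doctest):
--
-- > D.asValue (toDynamic (clamp 0.0 0.5 (ones :: CPUTensor 'D.Float '[2]))) :: [Float]
-- > -- [0.5,0.5]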

-- | clampMax
-- TODO: probably only defined for floating-point tensors, or maybe the numeric type is lifted?
--
-- >>> dtype &&& shape $ clampMax 1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
clampMax ::
  forall shape dtype device a.
  (D.Scalar a) =>
  -- | maximum value
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
clampMax max input = unsafePerformIO $ ATen.cast2 ATen.Managed.clamp_max_ts input max

-- | clampMin
-- TODO: probably only defined for floating-point tensors, or maybe the numeric type is lifted?
--
-- >>> dtype &&& shape $ clampMin 0 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
clampMin ::
  forall shape dtype device a.
  (D.Scalar a) =>
  -- | minimum value
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
clampMin min input = unsafePerformIO $ ATen.cast2 ATen.Managed.clamp_min_ts input min

-- | cudnnIsAcceptable
-- TODO: calling this probably only makes sense when the device is CUDA
cudnnIsAcceptable ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Bool
cudnnIsAcceptable input =
  unsafePerformIO $ ATen.cast1 ATen.Managed.cudnn_is_acceptable_t input
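
-- A minimal usage sketch (the result depends on the libtorch build and on whether
-- the device is CUDA; not verified by doctest):
--
-- > cudnnIsAcceptable (ones :: CPUTensor 'D.Float '[2, 2])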

-- constant_pad_nd :: Tensor device dtype shape -> [Int] -> Float -> Tensor device dtype shape
-- constant_pad_nd _input _pad _value = unsafePerformIO $ (ATen.cast3 ATen.Managed.constant_pad_nd_tls) _input _pad _value

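-- | constantPadNd1d
-- Pads a 1-dimensional tensor with the given value, adding @Fst pad@ elements on
-- the left and @Snd pad@ elements on the right (following ATen's @constant_pad_nd@
-- convention). A usage sketch (not verified by doctest):
--
-- > dtype &&& shape $ constantPadNd1d @'(1, 2) 0.0 (ones :: CPUTensor 'D.Float '[4])
-- > -- (Float,[7])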
constantPadNd1d ::
  forall (pad :: (Nat, Nat)) n dtype device.
  (All KnownNat '[Torch.Typed.Auxiliary.Fst pad, Torch.Typed.Auxiliary.Snd pad, n]) =>
  Float ->
  Tensor device dtype '[n] ->
  Tensor device dtype '[n + Torch.Typed.Auxiliary.Fst pad + Torch.Typed.Auxiliary.Snd pad]
constantPadNd1d value input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.constant_pad_nd_tls
      input
      ([natValI @(Torch.Typed.Auxiliary.Fst pad), natValI @(Torch.Typed.Auxiliary.Snd pad)] :: [Int])
      value

-- convolution :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> [Int] -> [Int] -> Bool -> [Int] -> Int -> Tensor device dtype shape
-- convolution _input _weight _bias _stride _padding _dilation _transposed _output_padding _groups = unsafePerformIO $ (ATen.cast9 ATen.Managed.convolution_tttlllbll) _input _weight _bias _stride _padding _dilation _transposed _output_padding _groups

type ConvSideCheck (inputSize :: Nat) (kernelSize :: Nat) (stride :: Nat) (padding :: Nat) (outputSize :: Nat) =
  ( -- kernel size and stride must be > 0
    1 <= kernelSize,
    1 <= stride,
    -- the kernel can't be larger than the padded input
    -- note: expressed via '<=' rather than '>=' to avoid a reduction stack overflow
    (kernelSize - 1) <= (inputSize + (2 * padding)),
    -- output size must be greater than 0
    1 <= outputSize,
    -- output formulation:
    outputSize ~ ConvOutputSize inputSize kernelSize stride padding
  )
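
-- 'ConvSideCheck' is intended to be used once per spatial dimension in convolution
-- signatures. A minimal sketch of a hypothetical helper (for illustration only):
--
-- > convOutputProxy ::
-- >   forall inputSize kernelSize stride padding outputSize dtype device.
-- >   ConvSideCheck inputSize kernelSize stride padding outputSize =>
-- >   Tensor device dtype '[inputSize] ->
-- >   Proxy outputSize
-- > convOutputProxy _ = Proxy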

-- | ConvOutputSize
--
-- >>> :kind! ConvOutputSize 4 1 1 0
-- ConvOutputSize 4 1 1 0 :: Natural
-- = 4
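--
-- A further sketch with non-unit stride and padding (assuming the usual
-- @((inputSize + 2 * padding - kernelSize) `Div` stride) + 1@ formula; not verified
-- by doctest):
--
-- > :kind! ConvOutputSize 5 3 2 1
-- > -- = 3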
type family ConvOutputSize (inputSize :: Nat) (kernelSize :: Nat) (stride :: Nat) (padding :: Nat) :: Nat where
  ConvOutputSize inputSize kernelSize stride padding = (Div ((inputSize