{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Functional where

import Control.Arrow ((&&&))
import Data.Finite
import qualified Data.Int as I
import Data.Kind
  ( Constraint,
    Type,
  )
import Data.Maybe
import Data.Proxy
import Data.Reflection
import Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as V
import Foreign.ForeignPtr
import GHC.Generics (Generic)
import GHC.Natural (Natural)
import GHC.TypeLits
import GHC.TypeLits.Extra
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.Functional
  ( Reduction (..),
    Tri (..),
    isUpper,
    kOne,
  )
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen.Managed
import qualified Torch.Internal.Managed.Native as ATen.Managed
import qualified Torch.Internal.Managed.Type.Scalar as ATen.Managed
import qualified Torch.Internal.Managed.Type.Tensor as ATen.Managed
import qualified Torch.Internal.Managed.Type.Tuple as ATen.Managed
import qualified Torch.Internal.Type as ATen
import qualified Torch.Scalar as D
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import qualified Torch.TensorOptions as D
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Tensor
import Prelude hiding
  ( abs,
    acos,
    acosh,
    all,
    any,
    asin,
    asinh,
    atan,
    atanh,
    ceil,
    cos,
    cosh,
    exp,
    floor,
    isNaN,
    log,
    max,
    min,
    round,
    sin,
    sinh,
    tan,
    tanh,
  )

-- $setup
--
-- >>> :set -XOverloadedLists

-- | Computes the bitwise NOT of the given input tensor.
-- The input tensor must be of integral or Boolean types.
-- For bool tensors, it computes the logical NOT.
--
-- >>> dtype &&& shape $ bitwiseNot (ones :: CPUTensor 'D.Bool [3,3])
-- (Bool,[3,3])
bitwiseNot ::
  forall device shape.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape
bitwiseNot input = unsafePerformIO $ ATen.cast1 ATen.Managed.bitwise_not_t input

-- | Computes the element-wise logical NOT of the given input tensor.
-- The output tensor has the bool dtype.
-- In the untyped API, non-bool inputs are also accepted, with zeros treated as False and non-zeros as True; this typed variant constrains both input and output to bool.
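--
-- An illustrative doctest (the bool dtype and the shape are preserved):
--
-- >>> dtype &&& shape $ logicalNot (ones :: CPUTensor 'D.Bool '[3,3])
-- (Bool,[3,3])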
logicalNot ::
  forall device shape.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape
logicalNot input = unsafePerformIO $ ATen.cast1 ATen.Managed.logical_not_t input

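-- | Computes the element-wise logical XOR of the two given boolean input tensors.
--
-- An illustrative doctest (the result keeps the bool dtype and the common shape):
--
-- >>> dtype &&& shape $ logicalXor (ones :: CPUTensor 'D.Bool '[2,2]) (ones :: CPUTensor 'D.Bool '[2,2])
-- (Bool,[2,2])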
logicalXor ::
  forall device shape.
  -- | self
  Tensor device 'D.Bool shape ->
  -- | other
  Tensor device 'D.Bool shape ->
  Tensor device 'D.Bool shape
logicalXor self other = unsafePerformIO $ ATen.cast2 ATen.Managed.logical_xor_tt self other

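-- | Computes the element-wise logical AND of the two given boolean input tensors.
--
-- An illustrative doctest (the result keeps the bool dtype and the common shape):
--
-- >>> dtype &&& shape $ logicalAnd (ones :: CPUTensor 'D.Bool '[2,2]) (ones :: CPUTensor 'D.Bool '[2,2])
-- (Bool,[2,2])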
logicalAnd ::
  forall device shape.
  -- | self
  Tensor device 'D.Bool shape ->
  -- | other
  Tensor device 'D.Bool shape ->
  Tensor device 'D.Bool shape
logicalAnd self other = unsafePerformIO $ ATen.cast2 ATen.Managed.logical_and_tt self other

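-- | Computes the element-wise logical OR of the two given boolean input tensors.
--
-- An illustrative doctest (the result keeps the bool dtype and the common shape):
--
-- >>> dtype &&& shape $ logicalOr (ones :: CPUTensor 'D.Bool '[2,2]) (ones :: CPUTensor 'D.Bool '[2,2])
-- (Bool,[2,2])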
logicalOr ::
  forall device shape.
  -- | self
  Tensor device 'D.Bool shape ->
  -- | other
  Tensor device 'D.Bool shape ->
  Tensor device 'D.Bool shape
logicalOr self other = unsafePerformIO $ ATen.cast2 ATen.Managed.logical_or_tt self other

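-- | 'SumDType' gives the accumulator dtype used by 'sumAll' and 'sumDim':
-- boolean and integral dtypes are summed as 'D.Int64', while floating-point
-- dtypes keep their dtype.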
type family SumDType (dtype :: D.DType) :: D.DType where
  SumDType D.Bool = D.Int64
  SumDType D.UInt8 = D.Int64
  SumDType D.Int8 = D.Int64
  SumDType D.Int16 = D.Int64
  SumDType D.Int32 = D.Int64
  SumDType D.Int64 = D.Int64
  SumDType D.Half = D.Half
  SumDType D.Float = D.Float
  SumDType D.Double = D.Double

type family SumDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SumDTypeIsValid '( 'D.CPU, 0) dtype = DTypeIsNotHalf '( 'D.CPU, 0) dtype
  SumDTypeIsValid '( 'D.CUDA, _) dtype = ()
  SumDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | sumAll
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Bool '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.UInt8 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Int8 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Int16 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Int32 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Int) $ sumAll (ones :: CPUTensor 'D.Int64 '[2, 3])
-- (Int64,([],6))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Float) $ sumAll (ones :: CPUTensor 'D.Float '[2, 3])
-- (Float,([],6.0))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: Double) $ sumAll (ones :: CPUTensor 'D.Double '[2, 3])
-- (Double,([],6.0))
sumAll ::
  forall shape dtype' dtype device.
  ( SumDTypeIsValid device dtype,
    dtype' ~ SumDType dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype' '[]
sumAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.sum_t input

-- | sumDim
--
-- >>> dtype &&& shape $ sumDim @0 (ones :: CPUTensor 'D.Float '[3,4,5])
-- (Float,[4,5])
-- >>> sumDim @1 (ones :: CPUTensor 'D.Float '[2,4])
-- Tensor Float [2] [ 4.0000   ,  4.0000   ]
sumDim ::
  forall d shape shape' dtype dtype' device.
  ( KnownNat d,
    shape' ~ DropValue shape d,
    SumDTypeIsValid device dtype,
    dtype' ~ SumDType dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype' shape'
sumDim input = unsafePerformIO $ ATen.cast2 ATen.Managed.sum_tl input (natValI @d)

-- | abs
--
-- >>> dtype &&& shape $ abs (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
abs ::
  forall shape dtype device.
  (StandardDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
abs input = unsafePerformIO $ ATen.cast1 ATen.Managed.abs_t input

-- | ceil
--
-- >>> dtype &&& shape $ ceil (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
ceil ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
ceil input = unsafePerformIO $ ATen.cast1 ATen.Managed.ceil_t input

-- | floor
--
-- >>> dtype &&& shape $ floor (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
floor ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
floor input = unsafePerformIO $ ATen.cast1 ATen.Managed.floor_t input

type family MinMaxDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MinMaxDTypeIsValid '( 'D.CPU, 0) dtype = DTypeIsNotHalf '( 'D.CPU, 0) dtype
  MinMaxDTypeIsValid '( 'D.CUDA, _) dtype = ()
  MinMaxDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | min
--
-- >>> dtype &&& shape $ min (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
min ::
  forall shape dtype device.
  ( MinMaxDTypeIsValid device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
min input = unsafePerformIO $ ATen.cast1 ATen.Managed.min_t input

-- | max
--
-- >>> dtype &&& shape $ max (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
max ::
  forall shape dtype device.
  ( MinMaxDTypeIsValid device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
max input = unsafePerformIO $ ATen.cast1 ATen.Managed.max_t input

type family MeanDTypeValidation (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MeanDTypeValidation '(deviceType, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '(deviceType, deviceIndex) dtype,
      DTypeIsNotHalf '(deviceType, deviceIndex) dtype
    )

-- | Computes the mean while carrying out a full reduction of all tensor dimensions.
--
-- >>> meanAll (ones :: CPUTensor 'D.Float '[])
-- Tensor Float []  1.0000
-- >>> meanAll (zeros :: CPUTensor 'D.Float '[2,2])
-- Tensor Float []  0.0000
meanAll ::
  forall shape dtype device.
  ( MeanDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
meanAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.mean_t input

-- | Computes the mean while carrying out a full reduction of all tensor dimensions.
-- Unlike 'meanAll', this version does not require all dimensions to be positive and can therefore return NaN, for example for empty tensors.
--
-- >>> unsafeMeanAll (ones :: CPUTensor 'D.Float '[])
-- Tensor Float []  1.0000
-- >>> unsafeMeanAll (ones :: CPUTensor 'D.Float '[0])
-- Tensor Float [] NaN
-- >>> unsafeMeanAll (zeros :: CPUTensor 'D.Float '[2,2])
-- Tensor Float []  0.0000
unsafeMeanAll ::
  forall shape dtype device.
  MeanDTypeValidation device dtype =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
unsafeMeanAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.mean_t input

-- | Computes the mean and reduces the tensor over the specified dimension.
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ meanDim @0 t
-- (Float,[4,5])
-- >>> dtype &&& shape $ meanDim @1 t
-- (Float,[3,5])
-- >>> dtype &&& shape $ meanDim @2 t
-- (Float,[3,4])
meanDim ::
  forall dim shape' shape dtype device.
  ( KnownNat dim,
    shape' ~ DropValue shape dim,
    MeanDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
meanDim input = unsafePerformIO $ ATen.cast2 ATen.Managed.mean_tl input (natValI @dim)

-- | Computes the mean and reduces the tensor over the specified dimension.
--
-- >>> import Torch.Typed.Factories
-- >>> import Data.Default.Class
-- >>> t = def :: NamedTensor '( D.CPU, 0) 'D.Float '[Vector 3, Vector 4, Vector 5]
-- >>> dtype &&& shape $ meanNamedDim @(Vector 4) t
-- (Float,[3,5])
meanNamedDim ::
  forall dim shape' shape dtype device.
  ( KnownNat (FindDim dim shape),
    shape' ~ DropNamedValue shape dim,
    MeanDTypeValidation device dtype
  ) =>
  -- | input
  NamedTensor device dtype shape ->
  -- | output
  NamedTensor device dtype shape'
meanNamedDim input = unsafePerformIO $ ATen.cast2 ATen.Managed.mean_tl input _dim
  where
    _dim :: Int
    _dim = natValI @(FindDim dim shape)

-- | Computes the mean and optionally reduces the tensor over the specified dimension.
--
-- See https://pytorch.org/docs/stable/torch.html#torch.mean for more information.
--
-- >>> t = fromJust [[5, 1], [3, 2], [4, 1], [2, 7]] :: CPUTensor 'D.Float '[4, 2]
-- >>> mean @0 @KeepDim t
-- Tensor Float [1,2] [[ 3.5000   ,  2.7500   ]]
mean ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    MeanDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
mean input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.mean_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- | Computes the median while carrying out a full reduction of all tensor dimensions.
--
-- >>> dtype &&& shape $ medianAll (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
medianAll ::
  forall shape dtype device.
  ( StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
medianAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.median_t input

-- | Computes the median and reduces the tensor over the specified dimension.
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ fst $ medianDim @0 t
-- (Float,[4,5])
-- >>> dtype &&& shape $ fst $ medianDim @1 t
-- (Float,[3,5])
-- >>> dtype &&& shape $ fst $ medianDim @2 t
-- (Float,[3,4])
medianDim ::
  forall dim shape' shape dtype device.
  ( KnownNat dim,
    shape' ~ DropValue shape dim,
    StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  ( Tensor device dtype shape',
    Tensor device 'D.Int64 shape'
  )
medianDim input = unsafePerformIO $ ATen.cast2 ATen.Managed.median_tl input (natValI @dim)

-- | Computes the median and optionally reduces the tensor over the specified dimension.
--
-- See https://pytorch.org/docs/stable/torch.html#torch.median for more information.
--
-- >>> t = fromJust [[5, 1], [3, 2], [4, 1], [2, 7]] :: CPUTensor 'D.Float '[4, 2]
--
-- -- libtorch 1.7.0
-- -- (Tensor Float [1,2] [[ 3.0000   ,  1.0000   ]],Tensor Int64 [1,2] [[ 1,  0]])
-- -- libtorch 1.8.0
-- >>> median @0 @KeepDim t
-- (Tensor Float [1,2] [[ 3.0000   ,  1.0000   ]],Tensor Int64 [1,2] [[ 1,  2]])
median ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  (Tensor device dtype shape', Tensor device 'D.Int64 shape')
median input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.median_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- | Returns a tuple '(modes, indices)' where 'modes' is the mode value of each row of the 'input' tensor
-- in the given dimension 'dim', i.e. a value which appears most often in that row,
-- and 'indices' is the index location of each mode value found.
--
-- See https://pytorch.org/docs/stable/torch.html#torch.mode for more information.
--
-- >>> t = fromJust [[0, 5], [0, 2], [3, 5]] :: CPUTensor 'D.Int64 '[3, 2]
--
-- >>> (modes :: CPUTensor 'D.Int64 '[2], indices :: CPUTensor 'D.Int64 '[2]) = mode @0 @DropDim t
-- >>> (dtype modes, shape modes, D.asValue (toDynamic modes) :: [Int])
-- (Int64,[2],[0,5])
-- >>> (dtype indices, shape indices, D.asValue (toDynamic indices) :: [Int])
-- (Int64,[2],[1,2])
--
-- >>> t = fromJust [[0, 0], [0, 1], [3, 3]] :: CPUTensor 'D.Float '[3, 2]
--
-- >>> (modes :: CPUTensor 'D.Float '[3,1], indices :: CPUTensor 'D.Int64 '[3,1]) = mode @1 @KeepDim t
-- >>> (dtype modes, shape modes, D.asValue (toDynamic modes) :: [[Float]])
-- (Float,[3,1],[[0.0],[0.0],[3.0]])
-- >>> (dtype indices, shape indices, D.asValue (toDynamic indices) :: [[Int]])
-- (Int64,[3,1],[[1],[0],[1]])
mode ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  (Tensor device dtype shape', Tensor device 'D.Int64 shape')
mode input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.mode_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- | addScalar
-- TODO: what dtypes is this defined for?
-- TODO: what scalar types is this defined for?
--
-- >>> dtype &&& shape $ addScalar 1 (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
addScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar input
  a ->
  -- | tensor input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
addScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.add_ts input a

-- | subScalar
-- TODO: what dtypes is this defined for?
-- TODO: what scalar types is this defined for?
--
-- >>> dtype &&& shape $ subScalar 1 (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
subScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar input
  a ->
  -- | tensor input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
subScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.sub_ts input a

-- | mulScalar
-- TODO: what dtypes is this defined for?
-- TODO: what scalar types is this defined for?
--
-- >>> dtype &&& shape $ mulScalar 2 (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
mulScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar input
  a ->
  -- | tensor input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
mulScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.mul_ts input a

-- | divScalar
-- TODO: what dtypes is this defined for?
-- TODO: what scalar types is this defined for?
--
-- >>> dtype &&& shape $ divScalar 2 (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
divScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar input
  a ->
  -- | tensor input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
divScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.div_ts input a

-- | powScalar
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ powScalar 2 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
powScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | power
  a ->
  -- | input tensor
  Tensor device dtype shape ->
  -- | output tensor
  Tensor device dtype shape
powScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.pow_ts input a

-- | erf
--
-- >>> dtype &&& shape $ erf (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
erf ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
erf input = unsafePerformIO $ ATen.cast1 ATen.Managed.erf_t input

-- | exp
--
-- >>> dtype &&& shape $ exp (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
exp ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
exp input = unsafePerformIO $ ATen.cast1 ATen.Managed.exp_t input

-- | log1p
--
-- >>> dtype &&& shape $ log1p (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
log1p ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log1p input = unsafePerformIO $ ATen.cast1 ATen.Managed.log1p_t input

-- | log2
-- >>> dtype &&& shape $ log2 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
log2 ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log2 input = unsafePerformIO $ ATen.cast1 ATen.Managed.log2_t input

-- | log10
--
-- >>> dtype &&& shape $ log10 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
log10 ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log10 input = unsafePerformIO $ ATen.cast1 ATen.Managed.log10_t input

-- | pow
-- this operation supports broadcasting
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ pow (2 :: CPUTensor 'D.Float '[]) (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
pow ::
  forall shape'' shape shape' dtype device.
  ( BasicArithmeticDTypeIsValid device dtype,
    shape'' ~ Broadcast shape shape'
  ) =>
  -- | power
  Tensor device dtype shape ->
  -- | input tensor
  Tensor device dtype shape' ->
  -- | output tensor
  Tensor device dtype shape''
pow exponent input = unsafePerformIO $ ATen.cast2 ATen.Managed.pow_tt input exponent

-- | relu activation function
--
-- >>> dtype &&& shape $ relu (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
relu ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
relu input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.relu_t (Wrap input)

-- | selu
--
-- >>> dtype &&& shape $ selu (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
selu ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
selu input = unsafePerformIO $ ATen.cast1 ATen.Managed.selu_t input

-- | mish
-- 'mish' is a smooth activation function; see https://arxiv.org/abs/1908.08681 for details.
--
-- >>> dtype &&& shape &&& (\t -> D.asValue (toDynamic t) :: [[Float]]) $ mish (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,([3,2],[[0.86509836,0.86509836],[0.86509836,0.86509836],[0.86509836,0.86509836]]))
mish ::
  forall shape dtype device.
  ( StandardFloatingPointDTypeValidation device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    shape ~ Broadcast shape shape
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
mish = mul =<< tanh . softplus (1 :: Float) 20

-- | sigmoid
--
-- >>> dtype &&& shape $ sigmoid (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
sigmoid ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
sigmoid input = unsafePerformIO $ ATen.cast1 ATen.Managed.sigmoid_t input

-- | sin
--
-- >>> dtype &&& shape $ sin (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
sin ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
sin input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.sin_t (Wrap input)

-- | sinh
--
-- >>> dtype &&& shape $ sinh (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
sinh ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
sinh input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.sinh_t (Wrap input)

-- | cos
--
-- >>> dtype &&& shape $ cos (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
cos ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
cos input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.cos_t (Wrap input)

-- | sqrt
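--
-- An illustrative doctest (dtype and shape are preserved):
--
-- >>> dtype &&& shape $ sqrt (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])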
sqrt ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
sqrt input = unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.sqrt_t (Wrap input)

-- | tanh
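--
-- An illustrative doctest (dtype and shape are preserved):
--
-- >>> dtype &&& shape $ tanh (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])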
tanh ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
tanh input = unsafePerformIO $ ATen.cast1 ATen.Managed.tanh_t input

-- | ConditionalReduction
--
-- >>> :kind! ConditionalReduction '[3,2] ReduceNone
-- ConditionalReduction '[3,2] ReduceNone :: [Natural]
-- = '[3, 2]
-- >>> :kind! ConditionalReduction '[3,2] ReduceMean
-- ConditionalReduction '[3,2] ReduceMean :: [Natural]
-- = '[]
type family ConditionalReduction (shape :: [Nat]) (reduction :: Reduction) :: [Nat] where
  ConditionalReduction shape ReduceNone = shape
  ConditionalReduction shape _ = '[]

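-- | 'KnownReduction' reifies a type-level 'Reduction' to the integer code
-- expected by the ATen loss bindings: 0 for 'ReduceNone', 1 for 'ReduceMean',
-- and 2 for 'ReduceSum'.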
class KnownReduction reduction where
  reductionVal :: Int

instance KnownReduction ReduceNone where
  reductionVal = 0

instance KnownReduction ReduceMean where
  reductionVal = 1

instance KnownReduction ReduceSum where
  reductionVal = 2

-- | binary cross entropy
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> dtype &&& shape $ binaryCrossEntropy @ReduceNone t t t
-- (Float,[2,2])
-- >>> dtype &&& shape $ binaryCrossEntropy @ReduceMean t t t
-- (Float,[])
-- >>> dtype &&& shape $ binaryCrossEntropy @ReduceSum t t t
-- (Float,[])
binaryCrossEntropy ::
  forall (reduction :: Reduction) shape shape' dtype device.
  ( KnownReduction reduction,
    shape' ~ ConditionalReduction shape reduction,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  -- | weight
  Tensor device dtype shape ->
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
binaryCrossEntropy weight prediction target =
  unsafePerformIO $
    ATen.cast4
      ATen.Managed.binary_cross_entropy_tttl
      prediction
      target
      weight
      (reductionVal @reduction)

-- | mseLoss
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> dtype &&& shape $ mseLoss @ReduceNone t t
-- (Float,[2,2])
-- >>> dtype &&& shape $ mseLoss @ReduceMean t t
-- (Float,[])
-- >>> dtype &&& shape $ mseLoss @ReduceSum t t
-- (Float,[])
mseLoss ::
  forall (reduction :: Reduction) shape shape' dtype device.
  ( KnownReduction reduction,
    shape' ~ ConditionalReduction shape reduction,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
mseLoss prediction target =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.mse_loss_ttl
      prediction
      target
      (reductionVal @reduction)

-- | softmax
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> dtype &&& shape $ softmax @0 t
-- (Float,[2,2])
-- >>> dtype &&& shape $ softmax @1 t
-- (Float,[2,2])
softmax ::
  forall dim shape dtype device.
  ( KnownNat dim,
    DimOutOfBoundCheck shape dim,
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
softmax input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.softmax_tls input (natValI @dim) (dtypeVal @dtype)

-- | logSoftmax
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> dtype &&& shape $ logSoftmax @0 t
-- (Float,[2,2])
-- >>> dtype &&& shape $ logSoftmax @1 t
-- (Float,[2,2])
logSoftmax ::
  forall dim shape dtype device.
  ( KnownNat dim,
    DimOutOfBoundCheck shape dim,
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
logSoftmax input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.log_softmax_tls input (natValI @dim) (dtypeVal @dtype)

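-- | 'Square' accepts the shape of a square matrix, @'[n, n]@, or of a batch of
-- square matrices, @'[b, n, n]@, and rejects any other shape with a type error.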
type family Square (shape :: [Nat]) :: [Nat] where
  Square (n : n : '[]) = '[n, n]
  Square (b : n : n : '[]) = '[b, n, n]
  Square _ = TypeError (Text "This shape must be a square matrix or a batch of square matrices.")

type family VectorOfSquare (shape :: [Nat]) :: [Nat] where
  VectorOfSquare (n : n : '[]) = '[n]
  VectorOfSquare (b : n : n : '[]) = '[b, n]
  VectorOfSquare _ = TypeError (Text "This shape must be a square matrix or a batch of square matrices.")

type family FstSquareDim (shape :: [Nat]) :: Nat where
  FstSquareDim (n : m : '[]) = n
  FstSquareDim (b : n : m : '[]) = n
  FstSquareDim _ = TypeError (Text "Cannot get the first dimension of a matrix or a batch of matrices.")

type family InverseShapeIsValid (device :: (D.DeviceType, Nat)) (shape :: [Nat]) :: Constraint where
  InverseShapeIsValid '( 'D.CPU, 0) _ = ()
  InverseShapeIsValid '( 'D.CUDA, _) shape = AllDimsPositive shape

type family InverseDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  InverseDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  InverseDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  InverseDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | inverse
-- TODO: if rank < n for any tensors in the batch, then this will not work. we can't decide this statically, but we should prevent runtime errors. therefore, return Maybe?
--
-- >>> t <- randn :: IO (CPUTensor 'D.Float '[3,2,2])
-- >>> dtype &&& shape $ inverse t
-- (Float,[3,2,2])
-- >>> t <- randn :: IO (CPUTensor 'D.Float '[2,2])
-- >>> dtype &&& shape $ inverse t
-- (Float,[2,2])
inverse ::
  forall shape shape' dtype device.
  ( shape' ~ Square shape,
    InverseShapeIsValid device shape,
    InverseDTypeIsValid device dtype
  ) =>
  -- | inverse
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
inverse input = unsafePerformIO $ ATen.cast1 ATen.Managed.inverse_t input

type family SymeigDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SymeigDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  SymeigDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  SymeigDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | symeig
-- Warning:
-- torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future PyTorch release.
-- The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
-- L, _ = torch.symeig(A, upper=upper)
-- should be replaced with
-- L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
-- and
-- L, V = torch.symeig(A, eigenvectors=True)
-- should be replaced with
-- L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (function operator())
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[3,2,2])
-- >>> (eigenVals,eigenVecs) = symeig Upper t
-- >>> dtype &&& shape $ eigenVals -- Skip warning
-- ...
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> :t eigenVals
-- eigenVals :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 2]
-- >>> dtype &&& shape $ eigenVecs
-- (Float,[3,2,2])
-- >>> :t eigenVecs
-- eigenVecs :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 2, 2]
-- >>> (eigenVals,eigenVecs) = symeig Lower t
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> dtype &&& shape $ eigenVecs
-- (Float,[3,2,2])
symeig ::
  forall shape shape' shape'' dtype device.
  ( shape' ~ VectorOfSquare shape,
    shape'' ~ Square shape,
    SymeigDTypeIsValid device dtype
  ) =>
  -- | upper or lower triangular
  Tri ->
  -- | input
  Tensor device dtype shape ->
  -- | eigenvalues and eigenvectors
  ( Tensor device dtype shape',
    Tensor device dtype shape''
  )
symeig upper input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.symeig_tbb input True boolUpper
  where
    boolUpper = isUpper upper

-- | symeigvalues
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[3,2,2])
-- >>> eigenVals = symeigvalues Upper t
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> :t eigenVals
-- eigenVals :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 2]
symeigvalues ::
  forall shape shape' dtype device.
  ( shape' ~ VectorOfSquare shape,
    SymeigDTypeIsValid device dtype
  ) =>
  -- | upper or lower triangular
  Tri ->
  -- | input
  Tensor device dtype shape ->
  Tensor device dtype shape'
symeigvalues upper input = fst symeig'
  where
    boolUpper = isUpper upper
    symeig' :: (Tensor device dtype shape', Tensor device dtype shape'')
    symeig' = unsafePerformIO $ ATen.cast3 ATen.Managed.symeig_tbb input False boolUpper

data EigenVectors = EnableEigenVectors | DisableEigenVectors

class KnownEigenVectors a where
  enableEigenVectors :: Bool

instance KnownEigenVectors EnableEigenVectors where
  enableEigenVectors = True

instance KnownEigenVectors DisableEigenVectors where
  enableEigenVectors = False

type family ConditionalEigenVectors (eigenvectors :: EigenVectors) (n :: Nat) :: [Nat] where
  ConditionalEigenVectors EnableEigenVectors n = '[n, n]
  ConditionalEigenVectors DisableEigenVectors _ = '[0]
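
-- Illustrative reductions of 'ConditionalEigenVectors', read off directly from
-- the equations above (shown as plain comments, not checked by doctest):
--
-- > ConditionalEigenVectors 'EnableEigenVectors 3   =  '[3, 3]
-- > ConditionalEigenVectors 'DisableEigenVectors 3  =  '[0]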

type family EigDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  EigDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  EigDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  EigDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | eig
-- Warning:
-- torch.eig is deprecated in favor of torch.linalg.eig and will be removed in a future PyTorch release.
-- torch.linalg.eig returns complex tensors of dtype cfloat or cdouble rather than real tensors mimicking complex tensors.
-- L, _ = torch.eig(A)
-- should be replaced with
-- L_complex = torch.linalg.eigvals(A)
-- and
-- L, V = torch.eig(A, eigenvectors=True)
-- should be replaced with
-- L_complex, V_complex = torch.linalg.eig(A) (function operator())
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[3,3])
-- >>> (eigenVals,eigenVecs) = eig @EnableEigenVectors t
-- >>> dtype &&& shape $ eigenVals -- Skip warning
-- ...
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> :t eigenVals
-- eigenVals :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 2]
-- >>> dtype &&& shape $ eigenVecs
-- (Float,[3,3])
-- >>> :t eigenVecs
-- eigenVecs :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 3]
-- >>> (eigenVals,eigenVecs) = eig @DisableEigenVectors t
-- >>> dtype &&& shape $ eigenVals
-- (Float,[3,2])
-- >>> dtype &&& shape $ eigenVecs
-- (Float,[0])
-- >>> :t eigenVecs
-- eigenVecs :: Tensor '( 'D.CPU, 0) 'D.Float '[0]
eig ::
  forall eigenvectors n shape dtype device.
  ( KnownNat n,
    KnownEigenVectors eigenvectors,
    shape ~ ConditionalEigenVectors eigenvectors n,
    EigDTypeIsValid device dtype
  ) =>
  -- | input matrix
  Tensor device dtype '[n, n] ->
  -- | eigenvalues and eigenvectors
  ( Tensor device dtype '[n, 2],
    Tensor device dtype shape
  )
eig input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.eig_tb input (enableEigenVectors @eigenvectors)

type family SVDShapes (shape :: [Nat]) (reduced :: ReducedSVD) :: ([Nat], [Nat], [Nat]) where
  SVDShapes '[0, n] 'ThinSVD = '( '[0, 0], '[0], '[n, n])
  SVDShapes '[m, n] 'ThinSVD = '( '[m, Min m n], '[Min m n], '[n, Min m n])
  SVDShapes '[m, n] 'FullSVD = '( '[m, m], '[Min m n], '[n, n])
  SVDShapes '[b, 0, n] 'ThinSVD = '( '[b, 0, 0], '[b, 0], '[b, n, n])
  SVDShapes '[b, m, n] 'ThinSVD = '( '[b, m, Min m n], '[b, Min m n], '[b, n, Min m n])
  SVDShapes '[b, m, n] 'FullSVD = '( '[b, m, m], '[b, Min m n], '[b, n, n])
  SVDShapes _ _ = TypeError (Text "A singular value decomposition can only be computed for 2D matrices with at most one batch dimension.")
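
-- For intuition, 'SVDShapes' reduces as follows for a 3x5 input (illustrative
-- only, not checked by doctest; compare the @svd@ examples below):
--
-- > SVDShapes '[3, 5] 'ThinSVD  =  '( '[3, 3], '[3], '[5, 3])
-- > SVDShapes '[3, 5] 'FullSVD  =  '( '[3, 3], '[3], '[5, 5])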

data ReducedSVD = ThinSVD | FullSVD

class KnownReducedSVD (reduced :: ReducedSVD) where
  reducedSVD :: Bool

instance KnownReducedSVD ThinSVD where
  reducedSVD = True

instance KnownReducedSVD FullSVD where
  reducedSVD = False

type family SVDDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SVDDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  SVDDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  SVDDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | Singular Value Decomposition
-- TODO: When `compute_uv` is `False`, backward cannot be performed since `u` and `v` from the forward pass are required for the backward operation. There is currently no way to encode this in the types, so only `True` is supported.
--
-- This function returns a tuple `(u, s, v)`
-- that is the singular value decomposition of the input real matrix
-- (or batch of real matrices) such that
-- `input = U×diag(S)×V^T`.
--
-- >>> a <- randn :: IO (CPUTensor 'D.Float '[3, 5])
-- >>> (u, s, v) = svd @'ThinSVD a
-- >>> dtype &&& shape $ u
-- (Float,[3,3])
-- >>> dtype &&& shape $ s
-- (Float,[3])
-- >>> dtype &&& shape $ v
-- (Float,[5,3])
-- >>> (u, s, v) = svd @'FullSVD a
-- >>> dtype &&& shape $ u
-- (Float,[3,3])
-- >>> dtype &&& shape $ s
-- (Float,[3])
-- >>> dtype &&& shape $ v
-- (Float,[5,5])
-- >>> a <- randn :: IO (CPUTensor 'D.Float '[5, 3])
-- >>> (u, s, v) = svd @'ThinSVD a
-- >>> dtype &&& shape $ u
-- (Float,[5,3])
-- >>> dtype &&& shape $ s
-- (Float,[3])
-- >>> dtype &&& shape $ v
-- (Float,[3,3])
-- >>> (u, s, v) = svd @'FullSVD a
-- >>> dtype &&& shape $ u
-- (Float,[5,5])
-- >>> dtype &&& shape $ s
-- (Float,[3])
-- >>> dtype &&& shape $ v
-- (Float,[3,3])
svd ::
  forall reduced shape shapeU shapeS shapeV dtype device.
  ( KnownReducedSVD reduced,
    '(shapeU, shapeS, shapeV) ~ SVDShapes shape reduced,
    SVDDTypeIsValid device dtype
  ) =>
  -- | (batched) input real matrix
  Tensor device dtype shape ->
  -- | (batched) output tuple of `u`, `s`, and `v`
  ( Tensor device dtype shapeU,
    Tensor device dtype shapeS,
    Tensor device dtype shapeV
  )
svd input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.svd_tbb input (reducedSVD @reduced) True

type family CholeskyDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  CholeskyDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  CholeskyDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  CholeskyDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | cholesky
-- TODO: cholesky can throw if the input is not positive-definite.
-- Computes the Cholesky decomposition of a symmetric positive-definite matrix.
-- The operation supports batching.
--
-- Warning:
-- torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be removed in a future PyTorch release.
-- L = torch.cholesky(A)
-- should be replaced with
-- L = torch.linalg.cholesky(A)
-- and
-- U = torch.cholesky(A, upper=True)
-- should be replaced with
-- U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj() (function operator())
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[2,2])
-- >>> u = cholesky Upper (t `matmul` transpose2D t) -- Skip warning
-- ...
-- >>> dtype &&& shape $ u
-- (Float,[2,2])
-- >>> :t u
-- u :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2]
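--
-- Since @u@ is the upper factor here, the original matrix can be recovered
-- (up to numerical error) as @transpose2D u@ multiplied with @u@; as an
-- illustration (not checked by doctest), one would expect
--
-- > allclose 1e-03 1e-05 False (transpose2D u `matmul` u) (t `matmul` transpose2D t)
--
-- to evaluate to @True@.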
cholesky ::
  forall shape shape' dtype device.
  ( shape' ~ Square shape,
    CholeskyDTypeIsValid device dtype
  ) =>
  -- | indicate whether to return an upper or lower triangular matrix.
  Tri ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
cholesky upper input = unsafePerformIO $ ATen.cast2 ATen.Managed.cholesky_tb input boolUpper
  where
    boolUpper = isUpper upper

-- | choleskyInverse
-- Computes the inverse of a symmetric positive-definite matrix
-- using its Cholesky factor, returned, e.g., by `cholesky`.
-- Unlike `cholesky`, this operation does not support batching.
-- The inverse is computed using the LAPACK routine `?potri`.
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[2,2])
-- >>> tri = Upper
-- >>> u = cholesky tri (t `matmul` transpose2D t)
-- >>> dtype &&& shape $ choleskyInverse tri u
-- (Float,[2,2])
choleskyInverse ::
  forall n dtype device.
  ( 1 <= n,
    CholeskyDTypeIsValid device dtype
  ) =>
  -- | decides whether the upper or the lower triangular part of the input tensor is used
  Tri ->
  -- | the input 2-D tensor `u`, an upper or lower triangular Cholesky factor
  Tensor device dtype '[n, n] ->
  -- | the output 2-D tensor
  Tensor device dtype '[n, n]
choleskyInverse upper input =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.cholesky_inverse_tb input boolUpper
  where
    boolUpper = isUpper upper

-- | choleskySolve
-- Solves the system of linear equations represented by `a c = b`
-- using the Cholesky factor matrix `u` of `a` (returned, e.g., by `cholesky`),
-- where `a` is a positive semidefinite matrix.
-- The operation supports batching.
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[3,3])
-- >>> a = t `matmul` transpose2D t
-- >>> b <- rand :: IO (CPUTensor 'D.Float '[3,2])
-- >>> tri = Upper
-- >>> u = cholesky tri a
-- >>> dtype &&& shape $ choleskySolve tri b u
-- (Float,[3,2])
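--
-- The result @c@ solves @a c = b@ given the Cholesky factor @u@ of @a@; as an
-- illustration (not checked by doctest), one would expect
--
-- > allclose 1e-03 1e-05 False (a `matmul` choleskySolve tri b u) b
--
-- to evaluate to @True@ up to numerical tolerance.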
choleskySolve ::
  forall m_k m_m dtype device.
  ( Square m_m ~ m_m,
    FstSquareDim m_m ~ FstSquareDim m_k,
    1 <= FstSquareDim m_m,
    CholeskyDTypeIsValid device dtype
  ) =>
  -- | decides whether the upper or the lower triangular part of the input tensor `u` is used
  Tri ->
  -- | the (batched) RHS tensor `b`
  Tensor device dtype m_k ->
  -- | the (batched) input 2-D tensor `u`, an upper or lower triangular Cholesky factor
  Tensor device dtype m_m ->
  -- | the (batched) output 2-D tensor
  Tensor device dtype m_k
choleskySolve upper b u =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.cholesky_solve_ttb b u boolUpper
  where
    boolUpper = isUpper upper

type family SolveDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SolveDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  SolveDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  SolveDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | solve
-- Solves the system of linear equations represented by `a c = b` and also returns the LU decomposition of `a`.
-- `a` has to be a positive semidefinite matrix.
-- The operation supports batching.
--
-- Warning:
-- torch.solve is deprecated in favor of torch.linalg.solve and will be removed in a future PyTorch release.
-- torch.linalg.solve has its arguments reversed and does not return the LU factorization.
-- To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
-- X = torch.solve(B, A).solution
-- should be replaced with
-- X = torch.linalg.solve(A, B) (function operator())
--
-- >>> t <- rand :: IO (CPUTensor 'D.Float '[10,10])
-- >>> a = t `matmul` transpose2D t
-- >>> b <- rand :: IO (CPUTensor 'D.Float '[10,3])
-- >>> (c,lu) = solve b a
-- >>> dtype &&& shape $ c -- Skip warning
-- ...
-- >>> dtype &&& shape $ c
-- (Float,[10,3])
-- >>> dtype &&& shape $ lu
-- (Float,[10,10])
-- >>> :t c
-- c :: Tensor '( 'D.CPU, 0) 'D.Float '[10, 3]
-- >>> :t lu
-- lu :: Tensor '( 'D.CPU, 0) 'D.Float '[10, 10]
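--
-- The returned @c@ is the solution of @a c = b@; as an illustration
-- (not checked by doctest), one would expect
--
-- > allclose 1e-03 1e-05 False (a `matmul` c) b
--
-- to evaluate to @True@ up to numerical tolerance.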
solve ::
  forall m_k m_m dtype device.
  ( Square m_m ~ m_m,
    FstSquareDim m_m ~ FstSquareDim m_k,
    1 <= FstSquareDim m_m,
    SolveDTypeIsValid device dtype
  ) =>
  -- | the (batched) RHS tensor `b`
  Tensor device dtype m_k ->
  -- | the (batched) positive semidefinite matrix `a`
  Tensor device dtype m_m ->
  -- | the (batched) outputs c and lu
  ( Tensor device dtype m_k,
    Tensor device dtype m_m
  )
solve b a = unsafePerformIO $ ATen.cast2 ATen.Managed.solve_tt b a

-- | geqrf
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- `geqrf` computes a QR decomposition of the given `input` matrix,
-- but without constructing `Q` and `R` as explicit separate matrices.
-- Rather, this function directly calls the underlying LAPACK function `?geqrf`
-- which produces a tuple `(a, tau)` of intermediate results as defined in
-- the LAPACK documentation for `?geqrf`.
--
-- You can use `orgqr` on `(a, tau)` to compute the real orthogonal matrix `Q`,
-- but in general you may just want to use `qr` instead.
--
-- See the LAPACK documentation for `?geqrf` for further details,
-- https://software.intel.com/en-us/node/521004.
--
-- >>> (a, tau) = geqrf (ones :: CPUTensor 'D.Float '[3,4])
-- >>> dtype &&& shape $ a
-- (Float,[3,4])
-- >>> dtype &&& shape $ tau
-- (Float,[3])
-- >>> (a, tau) = geqrf (ones :: CPUTensor 'D.Float '[4,3])
-- >>> dtype &&& shape $ a
-- (Float,[4,3])
-- >>> dtype &&& shape $ tau
-- (Float,[3])
geqrf ::
  forall m n dtype device.
  -- | input matrix
  Tensor device dtype '[m, n] ->
  -- | tuple `(a, tau)` of output matrices
  ( Tensor device dtype '[m, n],
    Tensor device dtype '[Min m n]
  )
geqrf input = unsafePerformIO $ ATen.cast1 ATen.Managed.geqrf_t input

-- | orgqr
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- Computes the orthogonal matrix `Q` of a QR factorization
-- from the `(a, tau)` tuple returned by `geqrf`.
--
-- This directly calls the underlying LAPACK function `?orgqr`.
-- See the LAPACK documentation for `?orgqr` for further details,
-- https://software.intel.com/en-us/mkl-developer-reference-c-orgqr.
--
-- As of libtorch 1.7, the behavior of this function changed:
-- the first dimension must not be smaller than the second dimension.
--
-- >>> dtype &&& shape $ orgqr (ones :: CPUTensor 'D.Float '[4,3]) (ones :: CPUTensor 'D.Float '[3])
-- (Float,[4,3])
orgqr ::
  forall m n dtype device.
  ( KnownNat n,
    KnownNat m,
    n <= m
  ) =>
  Tensor device dtype '[m, n] ->
  Tensor device dtype '[n] ->
  Tensor device dtype '[m, n]
orgqr a tau = unsafePerformIO $ ATen.cast2 ATen.Managed.orgqr_tt a tau

-- | sign
-- works for all dtypes
--
-- >>> dtype &&& shape $ sign (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
sign ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
sign input = unsafePerformIO $ ATen.cast1 ATen.Managed.sign_t input

type family SetValue (shape :: [Nat]) (i :: Nat) (j :: Nat) :: [Nat] where
  SetValue '[] _ _ = '[]
  SetValue (x : xs) 0 j = j : xs
  SetValue (x : xs) i j = x : SetValue xs (i - 1) j

type family GetValue (shape :: [Nat]) (i :: Nat) :: Nat where
  GetValue '[] _ = TypeError (Text "Cannot find an element in the list.")
  GetValue (x : xs) 0 = x
  GetValue (x : xs) i = GetValue xs (i - 1)
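
-- Illustrative reductions of these index helpers (read off from the equations,
-- not checked by doctest):
--
-- > SetValue '[3, 2, 1] 1 5  =  '[3, 5, 1]
-- > GetValue '[3, 2, 1] 1    =  2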

-- | Transpose
--
-- >>> :kind! Transpose '[3,2] 0 1
-- Transpose '[3,2] 0 1 :: [Natural]
-- = '[2, 3]
-- >>> :kind! Transpose '[3,2,1] 1 2
-- Transpose '[3,2,1] 1 2 :: [Natural]
-- = '[3, 1, 2]
type family Transpose (shape :: [Nat]) (dim0 :: Nat) (dim1 :: Nat) :: [Nat] where
  Transpose s d0 d1 = (SetValue (SetValue s d0 (GetValue s d1)) d1 (GetValue s d0))

-- | transpose
-- See "../../../../deps/pytorch/aten/src/ATen/native/TensorShape.cpp".
--
-- >>> dtype &&& shape $ transpose @0 @1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[2,3])
-- >>> dtype &&& shape $ transpose @0 @1 (ones :: CPUTensor 'D.Float '[3,2,1])
-- (Float,[2,3,1])
-- >>> dtype &&& shape $ transpose @1 @2 (ones :: CPUTensor 'D.Float '[3,2,1])
-- (Float,[3,1,2])
transpose ::
  forall n m shape shape' dtype device.
  ( KnownNat n,
    KnownNat m,
    shape' ~ Transpose shape n m
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
transpose input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.transpose_tll input (natValI @n) (natValI @m)

-- | transpose2d, special case for a 2D tensor
--
-- >>> dtype &&& shape $ transpose2D (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[2,3])
transpose2D ::
  forall (i :: Nat) (j :: Nat) dtype device.
  -- | input
  Tensor device dtype '[i, j] ->
  -- | output
  Tensor device dtype '[j, i]
transpose2D = transpose @0 @1

class KnownTri (tri :: Tri) where
  triVal :: Tri

instance KnownTri Upper where
  triVal = Upper

instance KnownTri Lower where
  triVal = Lower

type family DiagSize (tri :: Tri) (index :: Nat) (m :: Nat) (n :: Nat) :: Nat where
  DiagSize 'Upper i m n =
    If
      (i <=? n)
      (Min m (n - i))
      ( TypeError
          ( Text "For a matrix with shape "
              :<>: ShowType '[m, n]
              :<>: Text ", the maximum index for an upper diagonal is "
              :<>: ShowType n
              :<>: Text ", but asked for index "
              :<>: ShowType i
          )
      )
  DiagSize 'Lower i m n =
    If
      (i <=? m)
      (Min (m - i) n)
      ( TypeError
          ( Text "For a matrix with shape "
              :<>: ShowType '[m, n]
              :<>: Text ", the maximum index for a lower diagonal is "
              :<>: ShowType m
              :<>: Text ", but asked for index "
              :<>: ShowType i
          )
      )

type family DiagShape (tri :: Tri) (index :: Nat) (shape :: [Nat]) :: [Nat] where
  DiagShape _ i '[n] = '[n + i, n + i]
  DiagShape tri i '[m, n] = '[DiagSize tri i m n]
  DiagShape _ _ shape =
    TypeError
      ( Text "The input must be a matrix or a vector, but it has "
          :<>: ShowType (ListLength shape)
          :<>: Text " dimensions."
      )
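
-- Illustrative reductions of 'DiagShape' (not checked by doctest): for a matrix
-- the result is the length of the selected diagonal, for a vector it is the
-- shape of the square matrix carrying that vector on the given diagonal:
--
-- > DiagShape 'Upper 1 '[3, 2]  =  '[1]
-- > DiagShape 'Upper 1 '[3]     =  '[4, 4]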

-- | diag
--
-- >>> dtype &&& shape $ diag @'Upper @0 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[2])
-- >>> dtype &&& shape $ diag @'Upper @1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[1])
-- >>> dtype &&& shape $ diag @'Lower @1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[2])
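--
-- @'Upper@ selects the @index@-th diagonal above the main diagonal and
-- @'Lower@ the @index@-th diagonal below it; internally the offset is passed
-- to ATen as a positive or negative integer, respectively.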
diag ::
  forall tri index shape shape' device dtype.
  ( KnownTri tri,
    KnownNat index,
    StandardDTypeValidation device dtype,
    shape' ~ DiagShape tri index shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
diag t = unsafePerformIO $
  ATen.cast2 ATen.Managed.tensor_diag_l t $
    case triVal @tri of
      Upper -> natValI @index
      Lower -> - natValI @index

-- | all
-- See https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.all.
--
-- >>> t = all (fromJust [False, False] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- False
--
-- >>> t = all (fromJust [False, True] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- False
--
-- >>> t = all (fromJust [True, True] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- True
all ::
  forall shape device.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool '[]
all input = unsafePerformIO $ ATen.cast1 ATen.Managed.all_t input

-- | any
-- See https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.any.
--
-- >>> t = any (fromJust [False, False] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- False
--
-- >>> t = any (fromJust [False, True] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- True
--
-- >>> t = any (fromJust [True, True] :: CPUTensor 'D.Bool '[2])
-- >>> toInt t == 1
-- True
any ::
  forall shape device.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool '[]
any input = unsafePerformIO $ ATen.cast1 ATen.Managed.any_t input

data KeepOrDropDim = KeepDim | DropDim

class KnownKeepOrDropDim keepOrDropDim where
  keepOrDropDimVal :: Bool

instance KnownKeepOrDropDim KeepDim where
  keepOrDropDimVal = True

instance KnownKeepOrDropDim DropDim where
  keepOrDropDimVal = False

type family ConditionalDropDimension (shape :: [Nat]) (dim :: Nat) (keepOrDropDim :: KeepOrDropDim) :: [Nat] where
  ConditionalDropDimension '[] _ _ = TypeError (Text "The specified dimension is not available.")
  ConditionalDropDimension (x : xs) 0 KeepDim = 1 ': xs
  ConditionalDropDimension (x : xs) 0 DropDim = xs
  ConditionalDropDimension (x : xs) i keepOrDropDim = x ': ConditionalDropDimension xs (i - 1) keepOrDropDim
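
-- Illustrative reductions of 'ConditionalDropDimension' (not checked by
-- doctest); they correspond to the @allDim@ and @anyDim@ examples below:
--
-- > ConditionalDropDimension '[4, 2] 1 'DropDim  =  '[4]
-- > ConditionalDropDimension '[4, 2] 1 'KeepDim  =  '[4, 1]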

-- | allDim
-- See https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.all.
--
-- >>> t = fromJust [[True, True], [True, False], [True, True], [True, True]] :: CPUTensor 'D.Bool '[4, 2]
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Bool]) $ allDim @1 @DropDim t
-- (Bool,([4],[True,False,True,True]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Bool]]) $ allDim @1 @KeepDim t
-- (Bool,([4,1],[[True],[False],[True],[True]]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Bool]) $ allDim @0 @DropDim t
-- (Bool,([2],[True,False]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Bool]]) $ allDim @0 @KeepDim t
-- (Bool,([1,2],[[True,False]]))
allDim ::
  forall dim keepOrDropDim shape' shape device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape'
allDim input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.all_tlb input (natValI @dim) (keepOrDropDimVal @keepOrDropDim)

-- | anyDim
-- See https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.any.
--
-- >>> t = fromJust [[True, True], [True, False], [True, True], [True, True]] :: CPUTensor 'D.Bool '[4, 2]
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Bool]) $ anyDim @1 @DropDim t
-- (Bool,([4],[True,True,True,True]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Bool]]) $ anyDim @1 @KeepDim t
-- (Bool,([4,1],[[True],[True],[True],[True]]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Bool]) $ anyDim @0 @DropDim t
-- (Bool,([2],[True,True]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Bool]]) $ anyDim @0 @KeepDim t
-- (Bool,([1,2],[[True,True]]))
anyDim ::
  forall dim keepOrDropDim shape' shape device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape'
anyDim input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.any_tlb input (natValI @dim) (keepOrDropDimVal @keepOrDropDim)

-- | dropout
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: get rid of IO by exposing the RNG state
-- TODO: can we use D.Scalar for the dropout probability?
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,2]
-- >>> t' <- dropout 0.5 False t
-- >>> dtype &&& shape $ t'
-- (Float,[3,2])
-- >>> t'' <- dropout 0.5 False t
-- >>> t ==. t''
-- Tensor Bool [3,2] [[ 1,  1],
--                    [ 1,  1],
--                    [ 1,  1]]
-- >>> t''' <- dropout 0.0 True t
-- >>> t ==. t'''
-- Tensor Bool [3,2] [[ 1,  1],
--                    [ 1,  1],
--                    [ 1,  1]]
-- >>> t'''' <- dropout 1.0 True t
-- >>> t''''
-- Tensor Float [3,2] [[ 0.0000,  0.0000],
--                     [ 0.0000,  0.0000],
--                     [ 0.0000,  0.0000]]
dropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  IO (Tensor device dtype shape)
dropout p train input = ATen.cast3 ATen.Managed.dropout_tdb input p train

-- | featureDropout
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: why not IO?
-- TODO: can we use D.Scalar for the dropout probability?
--
-- >>> c = featureDropout 0.1 True (ones :: CPUTensor 'D.Float '[2,2])
-- >>> dtype &&& shape $ c
-- (Float,[2,2])
featureDropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
featureDropout p train input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.feature_dropout_tdb input p train

-- | alphaDropout
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: why not IO?
-- TODO: can we use D.Scalar for the dropout probability?
--
-- >>> c = alphaDropout 0.1 True (ones :: CPUTensor 'D.Float '[2,2])
-- >>> dtype &&& shape $ c
-- (Float,[2,2])
alphaDropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
alphaDropout p train input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.alpha_dropout_tdb input p train

-- | featureAlphaDropout
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: why not IO?
-- TODO: can we use D.Scalar for the dropout probability?
--
-- >>> c = featureAlphaDropout 0.1 True (ones :: CPUTensor 'D.Float '[2,2])
-- >>> dtype &&& shape $ c
-- (Float,[2,2])
featureAlphaDropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
featureAlphaDropout p train input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.feature_alpha_dropout_tdb input p train

-- | acos
--
-- >>> dtype &&& shape $ acos (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
acos ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
acos input = unsafePerformIO $ ATen.cast1 ATen.Managed.acos_t input

-- | avgPool1d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = avgPool1d @1 @1 @0 (ones :: CPUTensor 'D.Float '[1,3,4])
-- >>> shape t
-- [1,3,4]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 4]
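--
-- Assuming 'ConvSideCheck' encodes the usual pooling size arithmetic
-- (an assumption, not spelled out here), the output length is
-- @(inputSize + 2 * padding - kernelSize) / stride + 1@ with integer division;
-- for kernel size 1, stride 1, and padding 0 this keeps the input length 4 above.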
avgPool1d ::
  forall
    kernelSize
    stride
    padding
    channelSize
    inputSize
    batchSize
    outputSize
    dtype
    device.
  ( All KnownNat '[kernelSize, stride, padding, channelSize, inputSize, batchSize],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize]
avgPool1d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.avg_pool1d_tlllbb
      input
      (natValI @kernelSize)
      (natValI @stride)
      (natValI @padding)
      False
      True

-- | adaptiveAvgPool1d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = adaptiveAvgPool1d @8 (ones :: CPUTensor 'D.Float '[1,3,16])
-- >>> shape t
-- [1,3,8]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 8]
adaptiveAvgPool1d ::
  forall outputSize channelSize inputSize batchSize dtype device.
  (All KnownNat '[channelSize, inputSize, batchSize, outputSize]) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize]
adaptiveAvgPool1d input =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.adaptive_avg_pool1d_tl input (natValI @outputSize)

-- | adaptiveMaxPool1d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> tt = adaptiveMaxPool1d @8 (ones :: CPUTensor 'D.Float '[1,3,16])
-- >>> shape . fst $ tt
-- [1,3,8]
-- >>> :t tt
-- tt
--   :: (Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 8],
--       Tensor '( 'D.CPU, 0) 'D.Int64 '[1, 3, 8])
adaptiveMaxPool1d ::
  forall outputSize channelSize inputSize batchSize dtype device.
  (All KnownNat '[channelSize, inputSize, batchSize, outputSize]) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  ( Tensor device dtype '[batchSize, channelSize, outputSize],
    Tensor device 'D.Int64 '[batchSize, channelSize, outputSize]
  )
adaptiveMaxPool1d input =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.adaptive_max_pool1d_tl input (natValI @outputSize)

-- | addmv
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: can we use D.Scalar for beta and alpha?
--
-- >>> t = addmv 1 1 (ones :: CPUTensor 'D.Float '[3,2]) (zeros :: CPUTensor 'D.Float '[2]) (ones :: CPUTensor 'D.Float '[])
-- >>> dtype &&& shape $ t
-- (Float,[3])
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[3]
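--
-- Following the semantics of @torch.addmv@, the result is @beta@ times @input@
-- plus @alpha@ times the matrix-vector product of @mat@ and @vec@, with
-- @input@ broadcast against the result shape @'[n]@.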
addmv ::
  forall shape' shape n m dtype device.
  ( KnownNat n,
    KnownNat m,
    shape' ~ Broadcast shape '[n]
  ) =>
  -- | beta
  Float ->
  -- | alpha
  Float ->
  -- | matrix
  Tensor device dtype '[n, m] ->
  -- | vector
  Tensor device dtype '[m] ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
addmv beta alpha mat vec input =
  unsafePerformIO $ ATen.cast5 ATen.Managed.addmv_tttss input mat vec beta alpha

-- affine_grid_generator :: Tensor device dtype shape -> [Int] -> Tensor device dtype shape
-- affine_grid_generator _theta _size = unsafePerformIO $ (ATen.cast2 ATen.Managed.affine_grid_generator_tl) _theta _size

-- | allclose
--
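-- The comparison follows the semantics of @torch.allclose@: the result is
-- @True@ iff, elementwise, @abs (input - other) <= atol + rtol * abs other@
-- holds (treating @NaN@s as equal when the flag is set).
--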
-- >>> allclose 0.1 0.1 True (ones :: CPUTensor 'D.Float '[3,3]) (ones :: CPUTensor 'D.Float '[3,3])
-- True
allclose ::
  forall shape dtype device.
  -- | relative tolerance
  Double ->
  -- | absolute tolerance
  Double ->
  -- | whether or not NaN equals NaN
  Bool ->
  -- | input tensor
  Tensor device dtype shape ->
  -- | other input tensor
  Tensor device dtype shape ->
  -- | output
  Bool
allclose rtol atol equalNaN input other =
  unsafePerformIO $ ATen.cast5 ATen.Managed.allclose_ttddb input other rtol atol equalNaN

-- | argmax
-- See https://pytorch.org/docs/stable/torch.html#torch.argmax.
--
-- >>> t = fromJust [[0, 1], [-1, 2], [0, 1], [0, -2]] :: CPUTensor 'D.Float '[4, 2]
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Int]) $ argmax @1 @DropDim t
-- (Int64,([4],[1,1,1,0]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Int]]) $ argmax @1 @KeepDim t
-- (Int64,([4,1],[[1],[1],[1],[0]]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Int]) $ argmax @0 @DropDim t
-- (Int64,([2],[0,1]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Int]]) $ argmax @0 @KeepDim t
-- (Int64,([1,2],[[0,1]]))
argmax ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device 'D.Int64 shape'
argmax input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.argmax_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- | argmin
-- See https://pytorch.org/docs/stable/torch.html#torch.argmin.
--
-- >>> t = fromJust [[0, 1], [-1, 2], [0, 1], [0, -2]] :: CPUTensor 'D.Float '[4, 2]
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Int]) $ argmin @1 @DropDim t
-- (Int64,([4],[0,0,0,1]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Int]]) $ argmin @1 @KeepDim t
-- (Int64,([4,1],[[0],[0],[0],[1]]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Int]) $ argmin @0 @DropDim t
-- (Int64,([2],[1,3]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Int]]) $ argmin @0 @KeepDim t
-- (Int64,([1,2],[[1,3]]))
argmin ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device 'D.Int64 shape'
argmin input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.argmin_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- as_strided :: Tensor device dtype shape -> [Int] -> [Int] -> Int -> Tensor device dtype shape
-- as_strided _input _size _stride _storage_offset = unsafePerformIO $ (ATen.cast4 ATen.Managed.as_strided_tlll) _input _size _stride _storage_offset

-- | asin
--
-- >>> dtype &&& shape $ asin (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
asin ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
asin input = unsafePerformIO $ ATen.cast1 ATen.Managed.asin_t input

-- | atan
--
-- >>> dtype &&& shape $ atan (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
atan ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
atan input = unsafePerformIO $ ATen.cast1 ATen.Managed.atan_t input

-- | baddbmm
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = baddbmm 1 1 (ones :: CPUTensor 'D.Float '[5,3,2]) (zeros :: CPUTensor 'D.Float '[5,2,4]) (ones :: CPUTensor 'D.Float '[])
-- >>> dtype &&& shape $ t
-- (Float,[5,3,4])
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 3, 4]
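--
-- Following the semantics of @torch.baddbmm@, the result is @beta@ times
-- @input@ plus @alpha@ times the batched matrix product of @batch1@ and
-- @batch2@ (cf. 'bmm'), with @input@ broadcast to @'[batchSize, n, m]@.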
baddbmm ::
  forall shape' shape batchSize n m k dtype device.
  ( KnownNat n,
    KnownNat m,
    KnownNat k,
    shape' ~ Broadcast shape '[batchSize, n, m]
  ) =>
  -- | beta
  Float ->
  -- | alpha
  Float ->
  -- | first batch
  Tensor device dtype '[batchSize, n, k] ->
  -- | second batch
  Tensor device dtype '[batchSize, k, m] ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
baddbmm beta alpha batch1 batch2 input =
  unsafePerformIO $ ATen.cast5 ATen.Managed.baddbmm_tttss input batch1 batch2 beta alpha

-- batch_norm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Double -> Double -> Bool -> Tensor device dtype shape
-- batch_norm _input _weight _bias _running_mean _running_var _training _momentum _eps _cudnn_enabled = unsafePerformIO $ (ATen.cast9 ATen.Managed.batch_norm_tttttbddb) _input _weight _bias _running_mean _running_var _training _momentum _eps _cudnn_enabled

-- bilinear :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- bilinear _input1 _input2 _weight _bias = unsafePerformIO $ (ATen.cast4 ATen.Managed.bilinear_tttt) _input1 _input2 _weight _bias

-- binary_cross_entropy_with_logits :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Int -> Tensor device dtype shape
-- binary_cross_entropy_with_logits _input _target _weight _pos_weight _reduction = unsafePerformIO $ (ATen.cast5 ATen.Managed.binary_cross_entropy_with_logits_ttttl) _input _target _weight _pos_weight _reduction

-- bincount :: Tensor device dtype shape -> Tensor device dtype shape -> Int -> Tensor device dtype shape
-- bincount _input _weights _minlength = unsafePerformIO $ (ATen.cast3 ATen.Managed.bincount_ttl) _input _weights _minlength

-- | batched matrix multiplication
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ bmm (ones :: CPUTensor 'D.Float '[5,3,2]) (zeros :: CPUTensor 'D.Float '[5,2,4])
-- (Float,[5,3,4])
bmm ::
  forall batchSize n m k dtype device.
  -- | input
  Tensor device dtype '[batchSize, n, k] ->
  -- | other input
  Tensor device dtype '[batchSize, k, m] ->
  -- | output
  Tensor device dtype '[batchSize, n, m]
bmm input other = unsafePerformIO $ ATen.cast2 ATen.Managed.bmm_tt input other

-- | BroadcastTensorsImpl
--
-- >>> type Ty = BroadcastTensorsImpl '[] 'Nothing
-- >>> :kind! Ty
-- Ty :: Maybe ([Natural], D.DType, (D.DeviceType, Natural))
-- = 'Nothing
-- >>> type Ty = BroadcastTensorsImpl '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1]] 'Nothing
-- >>> :kind! Ty
-- Ty :: Maybe ([Natural], D.DType, (D.DeviceType, Natural))
-- = 'Just '( '[2, 3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = BroadcastTensorsImpl '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1], Tensor '( 'D.CPU, 0) 'D.Float '[5, 1, 1]] 'Nothing
-- >>> :kind! Ty
-- Ty :: Maybe ([Natural], D.DType, (D.DeviceType, Natural))
-- = 'Just '( '[5, 2, 3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = BroadcastTensorsImpl '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1], Tensor '( 'D.CPU, 0) 'D.Float '[1, 5, 1]] 'Nothing
-- >>> :kind! Ty
-- Ty :: Maybe ([Natural], D.DType, (D.DeviceType, Natural))
-- = 'Nothing
type family BroadcastTensorsImpl (tensors :: [a]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  BroadcastTensorsImpl '[] 'Nothing = 'Nothing
  BroadcastTensorsImpl '[] ('Just '(reverseShape, dtype, device)) = 'Just '(Reverse reverseShape, dtype, device)
  BroadcastTensorsImpl (Tensor device dtype shape ': tensors) 'Nothing = BroadcastTensorsImpl tensors ('Just '(Reverse shape, dtype, device))
  BroadcastTensorsImpl (Tensor device dtype shape ': tensors) ('Just '(reverseShape', dtype, device)) = BroadcastTensorsImpl tensors (MaybeTriple (ComputeBroadcast (Reverse shape) reverseShape') ('Just dtype) ('Just device))
  BroadcastTensorsImpl (Tensor device dtype shape ': _) ('Just _) = 'Nothing

type family BroadcastTensorsCheck (tensors :: [a]) (result :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: [a] where
  BroadcastTensorsCheck tensors 'Nothing =
    TypeError
      ( Text "Cannot broadcast tensors due to incompatible shapes and/or dtypes: "
          :<>: ShowType tensors
      )
  BroadcastTensorsCheck tensors ('Just '(shape, dtype, device)) = HReplicateR (ListLength tensors) (Tensor device dtype shape)

type BroadcastTensors tensors =
  BroadcastTensorsCheck tensors (BroadcastTensorsImpl tensors 'Nothing)

-- | broadcast tensors
-- TODO: broadcastTensors returns garbage data and is hence broken
-- See https://pytorch.org/docs/stable/_modules/torch/functional.html#broadcast_tensors.
--
-- >>> x = ones :: CPUTensor 'D.Float '[1, 3]
-- >>> y = ones :: CPUTensor 'D.Float '[2, 1]
-- >>> z = ones :: CPUTensor 'D.Float '[5, 1, 1]
--
-- -- >>> x' :. y' :. z' :. HNil = broadcastTensors (x :. y :. z :. HNil)
-- -- >>> :type x'
-- -- x' :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 2, 3]
-- -- >>> dtype &&& shape &&& (\t -> D.asValue (toDynamic t) :: [[[Float]]]) $ x'
-- -- >>> :type y'
-- -- y' :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 2, 3]
-- -- >>> dtype &&& shape &&& (\t -> D.asValue (toDynamic t) :: [[[Float]]]) $ y'
-- -- >>> :type z'
-- -- z' :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 2, 3]
-- -- >>> dtype &&& shape &&& (\t -> D.asValue (toDynamic t) :: [[[Float]]]) $ z'
broadcastTensors ::
  forall tensors tensors'.
  ( tensors' ~ BroadcastTensors tensors,
    ATen.Castable (HList tensors) [D.ATenTensor],
    ATen.Castable (HList tensors') [D.ATenTensor]
  ) =>
  -- | input list of tensors
  HList tensors ->
  -- | output list of tensors
  HList tensors'
broadcastTensors tensors = unsafePerformIO $ ATen.cast1 ATen.Managed.broadcast_tensors_l tensors

type family CatImpl (dim :: Nat) (tensors :: [a]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  CatImpl _ '[] acc = acc
  CatImpl dim (Tensor device dtype shape ': tensors) acc = CatImpl dim tensors (MaybeTriple (ComputeCatShape dim shape acc) (ComputeCatDType dtype acc) (ComputeCatDevice device acc))

type family ComputeCatShape (dim :: Nat) (shape :: [Nat]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe [Nat] where
  ComputeCatShape 0 (x ': xs) Nothing = Just (x ': xs)
  ComputeCatShape dim (x ': xs) Nothing = AppendToMaybe x (ComputeCatShape (dim - 1) xs Nothing)
  ComputeCatShape 0 (x ': xs) (Just '(y ': xs, _, _)) = Just ((x + y) ': xs)
  ComputeCatShape dim (x ': xs) (Just '(x ': ys, dtype, device)) = AppendToMaybe x (ComputeCatShape (dim - 1) xs (Just '(ys, dtype, device)))
  ComputeCatShape _ _ _ = Nothing

type family ComputeCatDType (dtype :: D.DType) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe D.DType where
  ComputeCatDType dtype Nothing = Just dtype
  ComputeCatDType dtype (Just '(_, dtype, _)) = Just dtype
  ComputeCatDType _ _ = Nothing

type family ComputeCatDevice (device :: (D.DeviceType, Nat)) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe (D.DeviceType, Nat) where
  ComputeCatDevice device Nothing = Just device
  ComputeCatDevice device (Just '(_, _, device)) = Just device
  ComputeCatDevice _ _ = Nothing

type family CatCheck (res :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: ([Nat], D.DType, (D.DeviceType, Nat)) where
  CatCheck 'Nothing = TypeError (Text "Concatenation impossible.")
  CatCheck ('Just '(shape, dtype, device)) = '(shape, dtype, device)

-- | Cat
--
-- >>> type Ty = Cat 0 '[Tensor '( 'D.CPU, 0) 'D.Float '[1]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[1], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Cat 0 '[Tensor '( 'D.CPU, 0) 'D.Float '[1], Tensor '( 'D.CPU, 0) 'D.Float '[2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Cat 0 '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 3], Tensor '( 'D.CPU, 0) 'D.Float '[2, 3]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[3, 3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Cat 1 '[Tensor '( 'D.CPU, 0) 'D.Float '[3, 1], Tensor '( 'D.CPU, 0) 'D.Float '[3, 2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[3, 3], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Cat 1 '[Tensor '( 'D.CPU, 0) 'D.Float '[2, 5, 4, 2], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1, 4, 2], Tensor '( 'D.CPU, 0) 'D.Float '[2, 3, 4, 2], Tensor '( 'D.CPU, 0) 'D.Float '[2, 1, 4, 2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[2, 10, 4, 2], 'D.Float, '( 'D.CPU, 0))
type Cat dim tensors = CatCheck (CatImpl dim tensors Nothing)

-- | cat
--
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> t' = cat @0 (t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([2,2],[[1.0,1.0],[1.0,1.0]]))
-- >>> t' = cat @1 (t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([2,2],[[1.0,1.0],[1.0,1.0]]))
-- >>> t' = cat @0 (t :. t :. t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[6, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([6,2],[[1.0,1.0],[1.0,1.0],[1.0,1.0],[1.0,1.0],[1.0,1.0],[1.0,1.0]]))
-- >>> t' = cat @1 (t :. t :. t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 6]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([2,6],[[1.0,1.0,1.0,1.0,1.0,1.0],[1.0,1.0,1.0,1.0,1.0,1.0]]))
cat ::
  forall dim shape dtype device tensors.
  ( KnownNat dim,
    '(shape, dtype, device) ~ Cat dim tensors,
    ATen.Castable (HList tensors) [D.ATenTensor]
  ) =>
  -- | input list of tensors
  HList tensors ->
  -- | output tensor
  Tensor device dtype shape
cat tensors = unsafePerformIO $ ATen.cast2 ATen.Managed.cat_ll tensors (natValI @dim :: Int)

-- chain_matmul :: [Tensor device dtype shape] -> Tensor device dtype shape
-- chain_matmul _matrices = unsafePerformIO $ (ATen.cast1 ATen.Managed.chain_matmul_l) _matrices

type family ChunkImpl (chunkShapes :: Maybe [[Nat]]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) :: Maybe a where
  ChunkImpl (Just '[]) _ _ = Just '[]
  ChunkImpl (Just (shape ': shapes)) dtype device = AppendToMaybe (Tensor device dtype shape) (ChunkImpl (Just shapes) dtype device)
  ChunkImpl Nothing _ _ = Nothing

type family ChunkCheck (shape :: [Nat]) (dim :: Nat) (result :: Maybe a) :: a where
  ChunkCheck shape dim Nothing = DimOutOfBound shape dim
  ChunkCheck _ _ (Just result) = result

type family ComputeChunksChunkGo (n' :: Nat) (r :: Nat) (cmp :: Ordering) (cmp' :: Ordering) :: [Nat] where
  ComputeChunksChunkGo n' r GT _ = n' ': ComputeChunksChunkGo n' (r - n') (CmpNat (r - n') n') (CmpNat (r - n') 0)
  ComputeChunksChunkGo n' r EQ _ = n' ': ComputeChunksChunkGo n' (r - n') (CmpNat (r - n') n') (CmpNat (r - n') 0)
  ComputeChunksChunkGo n' r _ GT = '[r]
  ComputeChunksChunkGo n' _ _ _ = '[]

type family ComputeChunksChunkGo0 (n' :: Nat) (chunks :: Nat) :: [Nat] where
  ComputeChunksChunkGo0 _ 0 = '[]
  ComputeChunksChunkGo0 n' chunks = n' ': (ComputeChunksChunkGo0 n' (chunks - 1))

type family ComputeChunks (n :: Maybe Nat) (chunks :: Nat) :: Maybe [Nat] where
  ComputeChunks (Just n) chunks = Just (ComputeChunks' n chunks (Mod n chunks))
  ComputeChunks Nothing _ = Nothing

type family ComputeChunks' (n :: Nat) (chunks :: Nat) (m :: Nat) :: [Nat] where
  ComputeChunks' n chunks 0 = ComputeChunksChunkGo0 (Div n chunks) chunks
  ComputeChunks' n chunks _ = ComputeChunksChunkGo (Div (n + chunks - 1) chunks) n (CmpNat n (Div (n + chunks - 1) chunks)) (CmpNat n 0)

type family ChunkShapesImpl (chunks :: Maybe [Nat]) (dim :: Nat) (shape :: [Nat]) :: Maybe [[Nat]] where
  ChunkShapesImpl (Just (n ': ns)) dim shape = AppendToMaybe' (ReplaceDim dim shape n) (ChunkShapesImpl (Just ns) dim shape)
  ChunkShapesImpl (Just '[]) _ _ = Just '[]
  ChunkShapesImpl Nothing _ _ = Nothing

type ChunkShapes chunks dim shape = ChunkShapesImpl (ComputeChunks (ExtractDim dim shape) chunks) dim shape

type Chunk chunks dim shape dtype device = ChunkCheck shape dim (ChunkImpl (ChunkShapes chunks dim shape) dtype device)
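
-- The type-level bookkeeping above mirrors the chunking rule: each chunk has
-- size ceil(n / chunks) and the final chunk takes whatever remains, so fewer
-- than @chunks@ chunks may be produced. A worked reduction (illustrative,
-- following the families above): for a dimension of size 19 split into 6
-- chunks, Div (19 + 6 - 1) 6 = 4, hence
--
--   ComputeChunks ('Just 19) 6  ~  'Just '[4, 4, 4, 4, 3]
--
-- and, along dimension 0 of shape '[19, 4],
--
--   ChunkShapes 6 0 '[19, 4]  ~  'Just '[ '[4, 4], '[4, 4], '[4, 4], '[4, 4], '[3, 4]]
--
-- which is why the doctest for 'chunk' below yields five tensors rather than six.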

-- | chunk
--
-- -- >>> :type chunk @3 @1 (ones :: CPUTensor 'D.Float '[2, 2])
-- -- chunk @3 @1 (ones :: CPUTensor 'D.Float '[2, 2])
-- --   :: HList
-- --        '[Tensor '( 'D.CPU, 0) 'D.Float '[2, 1],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[2, 1]]
-- >>> t0 :. t1 :. HNil = chunk @3 @1 (ones :: CPUTensor 'D.Float '[2, 2])
-- >>> dtype &&& shape $ t0
-- (Float,[2,1])
-- >>> dtype &&& shape $ t1
-- (Float,[2,1])
--
-- -- >>> :type chunk @3 @1 (ones :: CPUTensor 'D.Float '[1, 0, 3])
-- -- chunk @3 @1 (ones :: CPUTensor 'D.Float '[1, 0, 3])
-- --   :: HList
-- --        '[Tensor '( 'D.CPU, 0) 'D.Float '[1, 0, 3],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[1, 0, 3],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[1, 0, 3]]
-- >>> t0 :. t1 :. t2 :. HNil = chunk @3 @1 (ones :: CPUTensor 'D.Float '[1, 0, 3])
-- >>> dtype &&& shape $ t0
-- (Float,[1,0,3])
-- >>> dtype &&& shape $ t1
-- (Float,[1,0,3])
-- >>> dtype &&& shape $ t2
-- (Float,[1,0,3])
--
-- -- >>> :type chunk @6 @0 (ones :: CPUTensor 'D.Float '[19, 4])
-- -- chunk @6 @0 (ones :: CPUTensor 'D.Float '[19, 4])
-- --   :: HList
-- --        '[Tensor '( 'D.CPU, 0) 'D.Float '[4, 4],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[4, 4],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[4, 4],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[4, 4],
-- --          Tensor '( 'D.CPU, 0) 'D.Float '[3, 4]]
-- >>> t0 :. t1 :. t2 :. t3 :. t4 :. HNil = chunk @6 @0 (ones :: CPUTensor 'D.Float '[19, 4])
-- >>> dtype &&& shape $ t0
-- (Float,[4,4])
-- >>> dtype &&& shape $ t1
-- (Float,[4,4])
-- >>> dtype &&& shape $ t2
-- (Float,[4,4])
-- >>> dtype &&& shape $ t3
-- (Float,[4,4])
-- >>> dtype &&& shape $ t4
-- (Float,[3,4])
chunk ::
  forall chunks dim shape dtype device tensorChunks.
  ( KnownNat chunks,
    KnownNat dim,
    tensorChunks ~ Chunk chunks dim shape dtype device,
    ATen.Castable (HList tensorChunks) [D.ATenTensor]
  ) =>
  -- | input tensor
  Tensor device dtype shape ->
  -- | output list of tensors
  HList tensorChunks
chunk input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.chunk_tll input (natValI @chunks :: Int) (natValI @dim :: Int)

-- | clamp
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: can we use D.Scalar for the minimum and maximum values?
--
-- >>> dtype &&& shape $ clamp 0 1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
clamp ::
  forall shape dtype device a.
  (D.Scalar a) =>
  -- | minimum value
  a ->
  -- | maximum value
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
clamp min max input = unsafePerformIO $ ATen.cast3 ATen.Managed.clamp_tss input min max

-- | clampMax
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: can we use D.Scalar for the maximum value?
--
-- >>> dtype &&& shape $ clampMax 1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
clampMax ::
  forall shape dtype device a.
  (D.Scalar a) =>
  -- | maximum value
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
clampMax max input = unsafePerformIO $ ATen.cast2 ATen.Managed.clamp_max_ts input max

-- | clampMin
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: can we use D.Scalar for the minimum value?
--
-- >>> dtype &&& shape $ clampMin 0 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
clampMin ::
  forall shape dtype device a.
  (D.Scalar a) =>
  -- | minimum value
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
clampMin min input = unsafePerformIO $ ATen.cast2 ATen.Managed.clamp_min_ts input min

-- | cudnnIsAcceptable
-- TODO: calling this probably makes only sense when the device is CUDA
cudnnIsAcceptable ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Bool
cudnnIsAcceptable input =
  unsafePerformIO $ ATen.cast1 ATen.Managed.cudnn_is_acceptable_t input

-- constant_pad_nd :: Tensor device dtype shape -> [Int] -> Float -> Tensor device dtype shape
-- constant_pad_nd _input _pad _value = unsafePerformIO $ (ATen.cast3 ATen.Managed.constant_pad_nd_tls) _input _pad _value

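-- | constantPadNd1d
--
-- Pads a 1-d tensor with a constant value; the two padding amounts are given
-- at the type level as a pair, so the output length is
-- @n + Fst pad + Snd pad@. A minimal usage sketch (only dtype and shape are
-- inspected here):
--
-- >>> dtype &&& shape $ constantPadNd1d @'(1, 2) 0.0 (ones :: CPUTensor 'D.Float '[3])
-- (Float,[6])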
constantPadNd1d ::
  forall (pad :: (Nat, Nat)) n dtype device.
  (All KnownNat '[Torch.Typed.Auxiliary.Fst pad, Torch.Typed.Auxiliary.Snd pad, n]) =>
  Float ->
  Tensor device dtype '[n] ->
  Tensor device dtype '[n + Torch.Typed.Auxiliary.Fst pad + Torch.Typed.Auxiliary.Snd pad]
constantPadNd1d value input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.constant_pad_nd_tls
      input
      ([natValI @(Torch.Typed.Auxiliary.Fst pad), natValI @(Torch.Typed.Auxiliary.Snd pad)] :: [Int])
      value

-- convolution :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> [Int] -> [Int] -> Bool -> [Int] -> Int -> Tensor device dtype shape
-- convolution _input _weight _bias _stride _padding _dilation _transposed _output_padding _groups = unsafePerformIO $ (ATen.cast9 ATen.Managed.convolution_tttlllbll) _input _weight _bias _stride _padding _dilation _transposed _output_padding _groups

type ConvSideCheck (inputSize :: Nat) (kernelSize :: Nat) (stride :: Nat) (padding :: Nat) (outputSize :: Nat) =
  ( -- kernel size and stride must be > 0
    1 <= kernelSize,
    1 <= stride,
    -- kernel size can't be greater than actual input size
    -- ToDo: Do not use '>=' on constraint to avoid reduction-stack-overflow.
    (kernelSize - 1) <= (inputSize + (2 * padding)),
    -- output size must be greater than 0
    1 <= outputSize,
    -- output formulation:
    outputSize ~ ConvOutputSize inputSize kernelSize stride padding
  )

-- | ConvOutputSize
--
-- >>> :kind! ConvOutputSize 4 1 1 0
-- ConvOutputSize 4 1 1 0 :: Natural
-- = 4
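--
-- A less trivial case, worked from the formula below: an input of size 5 with
-- kernel size 3, stride 2, and padding 1 gives Div ((5 + 2 * 1) - 3) 2 + 1 = 3.
--
-- >>> :kind! ConvOutputSize 5 3 2 1
-- ConvOutputSize 5 3 2 1 :: Natural
-- = 3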
type family ConvOutputSize (inputSize :: Nat) (kernelSize :: Nat) (stride :: Nat) (padding :: Nat) :: Nat where
  ConvOutputSize inputSize kernelSize stride padding = (Div ((inputSize + (2 * padding)) - kernelSize) stride) + 1

-- | conv1d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = conv1d @1 @0 (ones :: CPUTensor 'D.Float '[10, 3, 1]) (ones :: CPUTensor 'D.Float '[10]) (ones :: CPUTensor 'D.Float '[1, 3, 4])
-- >>> :type t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 10, 4]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[Float]]]) $ t
-- (Float,([1,10,4],[[[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0]]]))
conv1d ::
  forall
    (stride :: Nat)
    (padding :: Nat)
    inputChannelSize
    outputChannelSize
    kernelSize
    inputSize
    batchSize
    outputSize
    dtype
    device.
  ( All
      KnownNat
      '[ stride,
         padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize,
         inputSize,
         batchSize,
         outputSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  -- | weight
  Tensor device dtype '[outputChannelSize, inputChannelSize, kernelSize] ->
  -- | bias
  Tensor device dtype '[outputChannelSize] ->
  -- | input
  Tensor device dtype '[batchSize, inputChannelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, outputChannelSize, outputSize]
conv1d weight bias input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv1d_tttllll
      input
      weight
      bias
      (natValI @stride :: Int)
      (natValI @padding :: Int)
      (1 :: Int)
      (1 :: Int)

-- | conv2d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = conv2d @'(1, 1) @'(0, 0) (ones :: CPUTensor 'D.Float '[10, 3, 1, 1]) (ones :: CPUTensor 'D.Float '[10]) (ones :: CPUTensor 'D.Float '[1, 3, 4, 5])
-- >>> :type t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 10, 4, 5]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[[Float]]]]) $ t
-- (Float,([1,10,4,5],[[[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]]]]))
conv2d ::
  forall
    (stride :: (Nat, Nat))
    (padding :: (Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         inputSize0,
         inputSize1,
         batchSize,
         outputSize0,
         outputSize1
       ],
    ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  -- | weight
  Tensor device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1] ->
  -- | bias
  Tensor device dtype '[outputChannelSize] ->
  -- | input
  Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1]
conv2d weight bias input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv2d_tttllll
      input
      weight
      bias
      ([natValI @(Torch.Typed.Auxiliary.Fst stride), natValI @(Torch.Typed.Auxiliary.Snd stride)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst padding), natValI @(Torch.Typed.Auxiliary.Snd padding)] :: [Int])
      ([1, 1] :: [Int])
      (1 :: Int)

-- | conv3d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = conv3d @'(1, 1, 1) @'(0, 0, 0) (ones :: CPUTensor 'D.Float '[10, 3, 1, 1, 1]) (ones :: CPUTensor 'D.Float '[10]) (ones :: CPUTensor 'D.Float '[1, 3, 4, 5, 6])
-- >>> :type t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 10, 4, 5, 6]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[[[Float]]]]]) $ t
-- (Float,([1,10,4,5,6],[[[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,
4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]]]]))
conv3d ::
  forall
    (stride :: (Nat, Nat, Nat))
    (padding :: (Nat, Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    kernelSize2
    inputSize0
    inputSize1
    inputSize2
    batchSize
    outputSize0
    outputSize1
    outputSize2
    dtype
    device.
  ( All
      KnownNat
      '[ Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         kernelSize2,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 kernelSize0 (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 kernelSize2 (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  -- | weight
  Tensor device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1, kernelSize2] ->
  -- | bias
  Tensor device dtype '[outputChannelSize] ->
  -- | input
  Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2] ->
  -- | output
  Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1, outputSize2]
conv3d weight bias input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv3d_tttllll
      input
      weight
      bias
      ([natValI @(Fst3 stride), natValI @(Snd3 stride), natValI @(Trd3 stride)] :: [Int])
      ([natValI @(Fst3 padding), natValI @(Snd3 padding), natValI @(Trd3 padding)] :: [Int])
      ([1, 1, 1] :: [Int])
      (1 :: Int)

-- | convTBC
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- 1D convolution over an input of shape `[timeSize, batchSize, inputChannels]`.
--
-- >>> dtype &&& shape $ convTBC @1 (ones :: CPUTensor 'D.Float '[1,4,5]) (ones :: CPUTensor 'D.Float '[5]) (ones :: CPUTensor 'D.Float '[3,3,4])
-- (Float,[5,3,5])
-- >>> dtype &&& shape $ convTBC @0 (ones :: CPUTensor 'D.Float '[1,4,5]) (ones :: CPUTensor 'D.Float '[5]) (ones :: CPUTensor 'D.Float '[2,3,4])
-- (Float,[3,3,5])
-- >>> dtype &&& shape $ convTBC @0 (ones :: CPUTensor 'D.Float '[2,4,5]) (ones :: CPUTensor 'D.Float '[5]) (ones :: CPUTensor 'D.Float '[2,3,4])
-- (Float,[2,3,5])
convTBC ::
  forall padding timeSize batchSize kernelSize inputChannels outputChannels dtype device.
  (KnownNat padding) =>
  Tensor device dtype '[kernelSize, inputChannels, outputChannels] ->
  Tensor device dtype '[outputChannels] ->
  Tensor device dtype '[timeSize, batchSize, inputChannels] ->
  Tensor device dtype '[timeSize + padding * 2 + 1 - kernelSize, batchSize, outputChannels]
convTBC weight bias input =
  unsafePerformIO $ ATen.cast4 ATen.Managed.conv_tbc_tttl input weight bias (natValI @padding)

-- | convTranspose1d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = convTranspose1d @1 @0 (ones :: CPUTensor 'D.Float '[3, 10, 1]) (ones :: CPUTensor 'D.Float '[10]) (ones :: CPUTensor 'D.Float '[1, 3, 4])
-- >>> :type t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 10, 4]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[Float]]]) $ t
-- (Float,([1,10,4],[[[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0]]]))
convTranspose1d ::
  forall
    (stride :: Nat)
    (padding :: Nat)
    inputChannelSize
    outputChannelSize
    kernelSize
    inputSize
    batchSize
    outputSize
    dtype
    device.
  ( All
      KnownNat
      '[ stride,
         padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize,
         inputSize,
         batchSize,
         outputSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  -- | weight
  Tensor device dtype '[inputChannelSize, outputChannelSize, kernelSize] ->
  -- | bias
  Tensor device dtype '[outputChannelSize] ->
  -- | input
  Tensor device dtype '[batchSize, inputChannelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, outputChannelSize, outputSize]
convTranspose1d weight bias input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv_transpose1d_tttllll
      input
      weight
      bias
      (natValI @stride :: Int)
      (natValI @padding :: Int)
      (0 :: Int)
      (1 :: Int)

-- | convTranspose2d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = convTranspose2d @'(1, 1) @'(0, 0) (ones :: CPUTensor 'D.Float '[3, 10, 1, 1]) (ones :: CPUTensor 'D.Float '[10]) (ones :: CPUTensor 'D.Float '[1, 3, 4, 5])
-- >>> :type t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 10, 4, 5]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[[Float]]]]) $ t
-- (Float,([1,10,4,5],[[[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0]]]]))
convTranspose2d ::
  forall
    (stride :: (Nat, Nat))
    (padding :: (Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         inputSize0,
         inputSize1,
         batchSize,
         outputSize0,
         outputSize1
       ],
    ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  -- | weight
  Tensor device dtype '[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1] ->
  -- | bias
  Tensor device dtype '[outputChannelSize] ->
  -- | input
  Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1]
convTranspose2d weight bias input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv_transpose2d_tttllll
      input
      weight
      bias
      ([natValI @(Torch.Typed.Auxiliary.Fst stride), natValI @(Torch.Typed.Auxiliary.Snd stride)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst padding), natValI @(Torch.Typed.Auxiliary.Snd padding)] :: [Int])
      ([0, 0] :: [Int])
      (1 :: Int)

-- | convTranspose3d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = convTranspose3d @'(1, 1, 1) @'(0, 0, 0) (ones :: CPUTensor 'D.Float '[3, 10, 1, 1, 1]) (ones :: CPUTensor 'D.Float '[10]) (ones :: CPUTensor 'D.Float '[1, 3, 4, 5, 6])
-- >>> :type t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 10, 4, 5, 6]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[[[Float]]]]]) $ t
-- (Float,([1,10,4,5,6],[[[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,
4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]],[[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]],[[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0],[4.0,4.0,4.0,4.0,4.0,4.0]]]]]))
convTranspose3d ::
  forall
    (stride :: (Nat, Nat, Nat))
    (padding :: (Nat, Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    kernelSize2
    inputSize0
    inputSize1
    inputSize2
    batchSize
    outputSize0
    outputSize1
    outputSize2
    dtype
    device.
  ( All
      KnownNat
      '[ Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         kernelSize2,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 kernelSize0 (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 kernelSize2 (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  -- | weight
  Tensor device dtype '[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1, kernelSize2] ->
  -- | bias
  Tensor device dtype '[outputChannelSize] ->
  -- | input
  Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2] ->
  -- | output
  Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1, outputSize2]
convTranspose3d weight bias input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv_transpose3d_tttllll
      input
      weight
      bias
      ([natValI @(Fst3 stride), natValI @(Snd3 stride), natValI @(Trd3 stride)] :: [Int])
      ([natValI @(Fst3 padding), natValI @(Snd3 padding), natValI @(Trd3 padding)] :: [Int])
      ([0, 0, 0] :: [Int])
      (1 :: Int)

-- | cosh
--
-- >>> dtype &&& shape $ cosh (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
cosh ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
cosh input = unsafePerformIO $ ATen.cast1 ATen.Managed.cosh_t input

-- cosine_embedding_loss :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Double -> Int -> Tensor device dtype shape
-- cosine_embedding_loss _input1 _input2 _target _margin _reduction = unsafePerformIO $ (ATen.cast5 ATen.Managed.cosine_embedding_loss_tttdl) _input1 _input2 _target _margin _reduction

-- cudnn_affine_grid_generator :: Tensor device dtype shape -> Int -> Int -> Int -> Int -> Tensor device dtype shape
-- cudnn_affine_grid_generator _theta _N _C _H _W = unsafePerformIO $ (ATen.cast5 ATen.Managed.cudnn_affine_grid_generator_tllll) _theta _N _C _H _W

-- cudnn_batch_norm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Double -> Double -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- cudnn_batch_norm _input _weight _bias _running_mean _running_var _training _exponential_average_factor _epsilon = unsafePerformIO $ (ATen.cast8 ATen.Managed.cudnn_batch_norm_tttttbdd) _input _weight _bias _running_mean _running_var _training _exponential_average_factor _epsilon

-- cudnn_convolution :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> [Int] -> [Int] -> Int -> Bool -> Bool -> Tensor device dtype shape
-- cudnn_convolution _input _weight _bias _padding _stride _dilation _groups _benchmark _deterministic = unsafePerformIO $ (ATen.cast9 ATen.Managed.cudnn_convolution_tttllllbb) _input _weight _bias _padding _stride _dilation _groups _benchmark _deterministic

-- cudnn_convolution_transpose :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> [Int] -> [Int] -> [Int] -> Int -> Bool -> Bool -> Tensor device dtype shape
-- cudnn_convolution_transpose _input _weight _bias _padding _output_padding _stride _dilation _groups _benchmark _deterministic = unsafePerformIO $ (ATen.cast10 ATen.Managed.cudnn_convolution_transpose_tttlllllbb) _input _weight _bias _padding _output_padding _stride _dilation _groups _benchmark _deterministic

-- cudnn_grid_sampler :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- cudnn_grid_sampler _input _grid = unsafePerformIO $ (ATen.cast2 ATen.Managed.cudnn_grid_sampler_tt) _input _grid

-- | Det
--
-- >>> :kind! Det '[2,2]
-- Det '[2,2] :: [Natural]
-- = '[]
-- >>> :kind! Det '[3,2,2]
-- Det '[3,2,2] :: [Natural]
-- = '[3]
type family Det (shape :: [Nat]) :: [Nat] where
  Det (n : n : '[]) = '[]
  Det (b : n : n : '[]) = '[b]
  Det _ = TypeError (Text "This shape must be a square matrix or a batch of square matrices.")

-- | det
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ det (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
-- >>> dtype &&& shape $ det (ones :: CPUTensor 'D.Float '[3,2,2])
-- (Float,[3])
det ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype (Det shape)
det input = unsafePerformIO $ ATen.cast1 ATen.Managed.det_t input

type family DimsDistinctAscendingCheck (dim1 :: Nat) (dim2 :: Nat) (cmp :: Ordering) :: Constraint where
  DimsDistinctAscendingCheck _ _ 'LT = ()
  DimsDistinctAscendingCheck dim1 dim2 _ =
    TypeError
      ( Text "Dimensions must be distinct and in ascending order, but got "
          :<>: ShowType dim1
          :<>: Text ", "
          :<>: ShowType dim2
      )

type family DimsDistinctAscending (dim1 :: Nat) (dim2 :: Nat) :: Constraint where
  DimsDistinctAscending dim1 dim2 = DimsDistinctAscendingCheck dim1 dim2 (CmpNat dim1 dim2)

type family DiagEmbedShapeImpl (dim1 :: Nat) (dim2 :: Nat) (shape :: [Nat]) (n :: Nat) :: [Nat] where
  DiagEmbedShapeImpl dim1 dim2 shape n = Insert dim1 n (Insert (dim2 - 1) n (Init shape))

type family DiagEmbedShape (index :: Nat) (dim1 :: Nat) (dim2 :: Nat) (shape :: [Nat]) :: [Nat] where
  DiagEmbedShape index dim1 dim2 shape = DiagEmbedShapeImpl dim1 dim2 shape (Last shape + index)
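-- The computed output shape can be inspected in GHCi. A minimal sketch, kept
-- commented out so it is not part of the doctest run (the reduced kind may
-- print slightly differently depending on the GHC version):
--
-- -- >>> :kind! DiagEmbedShape 0 1 2 '[2,3]
-- -- DiagEmbedShape 0 1 2 '[2,3] :: [Natural]
-- -- = '[2, 3, 3]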

-- | diagEmbed
--
-- >>> dtype &&& shape $ diagEmbed @0 @1 @2 Upper (ones :: CPUTensor 'D.Float '[2,3])
-- (Float,[2,3,3])
-- >>> dtype &&& shape $ diagEmbed @1 @0 @2 Upper (ones :: CPUTensor 'D.Float '[2,3])
-- (Float,[4,2,4])
diagEmbed ::
  forall index dim1 dim2 shape shape' device dtype.
  ( KnownNat index,
    KnownNat dim1,
    KnownNat dim2,
    shape' ~ DiagEmbedShape index dim1 dim2 shape,
    DimsDistinctAscending dim1 dim2,
    StandardDTypeValidation device dtype
  ) =>
  Tri ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
diagEmbed tri t =
  unsafePerformIO $
    ATen.cast4
      ATen.Managed.diag_embed_tlll
      t
      (if isUpper tri then natValI @index else - natValI @index)
      (natValI @dim1)
      (natValI @dim2)

type family DiagflatShapeImpl (d :: Nat) :: [Nat] where
  DiagflatShapeImpl d = '[d, d]

type family DiagflatShape (index :: Nat) (shape :: [Nat]) :: [Nat] where
  DiagflatShape index shape = DiagflatShapeImpl (Numel shape + index)

-- | diagflat
--
-- >>> dtype &&& shape $ diagflat @0 Upper (ones :: CPUTensor 'D.Float '[3])
-- (Float,[3,3])
-- >>> dtype &&& shape $ diagflat @1 Upper (ones :: CPUTensor 'D.Float '[3])
-- (Float,[4,4])
-- >>> dtype &&& shape $ diagflat @0 Upper (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[4,4])
diagflat ::
  forall index shape shape' device dtype.
  ( KnownNat index,
    shape' ~ DiagflatShape index shape,
    StandardDTypeValidation device dtype
  ) =>
  Tri ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
diagflat tri t = unsafePerformIO $
  ATen.cast2 ATen.Managed.diagflat_tl t $
    case tri of
      Upper -> natValI @index
      Lower -> - natValI @index

type family NDimAtLeastCheck (ndim :: Nat) (shape :: [Nat]) (cmp :: Ordering) :: Constraint where
  NDimAtLeastCheck ndim shape 'GT =
    TypeError
      ( Text "Input must have at least "
          :<>: ShowType ndim
          :<>: Text " dimensions, but got "
          :<>: ShowType (ListLength shape)
      )
  NDimAtLeastCheck _ _ _ = ()

type family NDimAtLeast (ndim :: Nat) (shape :: [Nat]) :: Constraint where
  NDimAtLeast ndim shape = NDimAtLeastCheck ndim shape (CmpNat ndim (ListLength shape))

type family DiagonalShape (tri :: Tri) (index :: Nat) (dim1 :: Nat) (dim2 :: Nat) (shape :: [Nat]) :: [Nat] where
  DiagonalShape tri index dim1 dim2 shape =
    Remove (Remove shape dim2) dim1 ++ '[DiagSize tri index (Index shape dim1) (Index shape dim2)]

-- | diagonal
--
-- >>> dtype &&& shape $ diagonal @'Upper @0 @0 @1 (ones :: CPUTensor 'D.Float '[3,3])
-- (Float,[3])
-- >>> dtype &&& shape $ diagonal @'Upper @1 @0 @1 (ones :: CPUTensor 'D.Float '[3,3])
-- (Float,[2])
-- >>> dtype &&& shape $ diagonal @'Lower @1 @1 @2 (ones :: CPUTensor 'D.Float '[2,5,4,2])
-- (Float,[2,2,4])
diagonal ::
  forall tri index dim1 dim2 shape shape' device dtype.
  ( KnownTri tri,
    KnownNat index,
    KnownNat dim1,
    KnownNat dim2,
    NDimAtLeast 2 shape,
    DimsDistinctAscending dim1 dim2,
    shape' ~ DiagonalShape tri index dim1 dim2 shape,
    StandardDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
diagonal t =
  unsafePerformIO $
    ATen.cast4
      ATen.Managed.diagonal_tlll
      t
      (if isUpper (triVal @tri) then natValI @index else - natValI @index)
      (natValI @dim1)
      (natValI @dim2)

type family DotDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  DotDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsNotBool '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  DotDTypeIsValid '( 'D.CUDA, deviceIndex) dtype = DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype
  DotDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | dot product
-- Note that this function does not broadcast.
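--
-- An illustrative usage sketch, left commented out rather than run as a
-- doctest; the dot product of two same-length vectors is a scalar tensor of
-- shape '[]:
--
-- -- >>> dtype &&& shape $ dot (ones :: CPUTensor 'D.Float '[3]) (ones :: CPUTensor 'D.Float '[3])
-- -- (Float,[])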
dot ::
  forall size dtype device.
  DotDTypeIsValid device dtype =>
  -- | input
  Tensor device dtype '[size] ->
  -- | other input
  Tensor device dtype '[size] ->
  -- | dot product
  Tensor device dtype '[]
dot input other = unsafePerformIO $ ATen.cast2 ATen.Managed.dot_tt input other

-- einsum :: String -> [Tensor device dtype shape] -> Tensor device dtype shape
-- einsum _equation _tensors = unsafePerformIO $ (ATen.cast2 ATen.Managed.einsum_sl) _equation _tensors

class KnownMaybeNat (n :: Maybe Nat) where
  maybeNatVal :: Maybe Integer

instance (KnownNat n) => KnownMaybeNat (Just n) where
  maybeNatVal = Just . natVal $ Proxy @n

instance KnownMaybeNat Nothing where
  maybeNatVal = Nothing
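
-- A hedged sketch of how 'maybeNatVal' resolves for the two instances above,
-- kept commented out so it is not part of the doctest suite:
--
-- -- >>> maybeNatVal @('Just 3)
-- -- Just 3
-- -- >>> maybeNatVal @'Nothing
-- -- Nothing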

type family PaddingIdxCheck (idx :: Maybe Nat) (numEmbeds :: Nat) :: Constraint where
  PaddingIdxCheck (Just n) numEmbeds = n + 1 <= numEmbeds
  PaddingIdxCheck Nothing _ = ()

-- | embedding
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: what about sparsity here?
-- TODO: what output dtypes are supported?
--
-- >>> weights = fromJust [[1, 1], [2, 2], [3, 3], [4, 4]] :: CPUTensor 'D.Float '[4, 2]
-- >>> indices = fromJust [[0], [2], [0], [1]] :: CPUTensor 'D.Int64 '[4, 1]
-- >>> t = embedding @('Just 1) False False weights indices
-- >>> :type t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[4, 1, 2]
--
-- -- libtorch 1.7
-- -- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[Float]]]) $ t
-- -- (Float,([4,1,2],[[[1.0,1.0]],[[3.0,3.0]],[[1.0,1.0]],[[2.0,2.0]]]))
-- --
-- -- libtorch 1.8
-- -- The behavior of libtorch 1.8 changes. See https://github.com/pytorch/pytorch/issues/53368
-- -- (Float,([4,1,2],[[[1.0,1.0]],[[3.0,3.0]],[[1.0,1.0]],[[0.0,0.0]]]))
-- --
-- -- libtorch 1.8.1
-- -- The behavior of libtorch 1.8.1 is reverted.
-- -- (Float,([4,1,2],[[[1.0,1.0]],[[3.0,3.0]],[[1.0,1.0]],[[2.0,2.0]]]))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[Float]]]) $ t
-- (Float,([4,1,2],[[[1.0,1.0]],[[3.0,3.0]],[[1.0,1.0]],[[2.0,2.0]]]))
-- >>> t = embedding @'Nothing False False weights indices
-- >>> :type t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[4, 1, 2]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[Float]]]) $ t
-- (Float,([4,1,2],[[[1.0,1.0]],[[3.0,3.0]],[[1.0,1.0]],[[2.0,2.0]]]))
embedding ::
  forall (paddingIdx :: Maybe Nat) numEmbeds embedDim shape dtype device.
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds
  ) =>
  -- | whether or not to scale the gradient by the frequencies
  Bool ->
  -- | whether or not the embedding is sparse
  Bool ->
  -- | weights
  Tensor device dtype '[numEmbeds, embedDim] ->
  -- | indices
  Tensor device 'D.Int64 shape ->
  -- | output
  Tensor device dtype (Reverse (embedDim ': Reverse shape))
embedding scaleGradByFreq sparse weights indices =
  unsafePerformIO $ ATen.cast5 ATen.Managed.embedding_ttlbb weights indices paddingIdx scaleGradByFreq sparse
  where
    paddingIdx :: Int
    paddingIdx = case maybeNatVal @paddingIdx of
      Just idx -> fromIntegral idx
      Nothing -> -1

-- embedding_bag :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Int -> Bool -> Tensor device dtype shape -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- embedding_bag _weight _indices _offsets _scale_grad_by_freq _mode _sparse _per_sample_weights = unsafePerformIO $ (ATen.cast7 ATen.Managed.embedding_bag_tttblbt) _weight _indices _offsets _scale_grad_by_freq _mode _sparse _per_sample_weights

-- | emptyLike
-- TODO: this seems quite unsafe, the values of this tensor will be random
--
-- >>> t <- emptyLike (ones :: CPUTensor 'D.Float '[3,4,5])
-- >>> dtype &&& shape $ t
-- (Float,[3,4,5])
emptyLike ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  IO (Tensor device dtype shape)
emptyLike input = ATen.cast1 ATen.Managed.empty_like_t input

-- | erfc
--
-- >>> dtype &&& shape $ erfc (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
erfc ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
erfc input = unsafePerformIO $ ATen.cast1 ATen.Managed.erfc_t input

-- | expm1
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ expm1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
expm1 ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
expm1 input = unsafePerformIO $ ATen.cast1 ATen.Managed.expm1_t input

-- | expand
-- TODO: figure out what the `implicit` boolean value does
--
-- >>> t = ones :: CPUTensor 'D.Float '[2]
-- >>> t' = expand @'[3, 1, 2] False t
-- >>> dtype &&& shape $ t'
-- (Float,[3,1,2])
-- >>> t'' = expand @'[3, 1, 2] True t
-- >>> dtype &&& shape $ t''
-- (Float,[3,1,2])
-- >>> toInt (all (t' ==. t'')) == 1
-- True
expand ::
  forall shape' shape dtype device.
  ( KnownShape shape',
    shape' ~ Broadcast shape shape'
  ) =>
  -- | some boolean value with unknown function
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
expand someBool input = unsafePerformIO $ ATen.cast3 ATen.Managed.tensor_expand_lb input (shapeVal @shape') someBool

-- flatten :: Tensor device dtype shape -> Int -> Int -> Tensor device dtype shape
-- flatten _input _start_dim _end_dim = unsafePerformIO $ (ATen.cast3 ATen.Managed.flatten_tll) _input _start_dim _end_dim

-- | flattenAll
--
-- >>> t = flattenAll (ones :: CPUTensor 'D.Float '[3,2])
-- >>> dtype &&& shape $ t
-- (Float,[6])
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[6]
flattenAll ::
  forall shape dtype device.
  KnownShape shape =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[Product shape]
flattenAll input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.flatten_tll input (0 :: Int) (-1 :: Int)

-- | frac
--
-- >>> dtype &&& shape $ frac (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
frac ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
frac input = unsafePerformIO $ ATen.cast1 ATen.Managed.frac_t input

-- | full like
--
-- >>> dtype &&& shape $ fullLike 3.0 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
fullLike ::
  forall shape dtype device.
  -- | fill value
  Float ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
fullLike fillValue input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.full_like_ts input fillValue

-- grid_sampler :: Tensor device dtype shape -> Tensor device dtype shape -> Int -> Int -> Tensor device dtype shape
-- grid_sampler _input _grid _interpolation_mode _padding_mode = unsafePerformIO $ (ATen.cast4 ATen.Managed.grid_sampler_ttll) _input _grid _interpolation_mode _padding_mode

-- grid_sampler_2d :: Tensor device dtype shape -> Tensor device dtype shape -> Int -> Int -> Tensor device dtype shape
-- grid_sampler_2d _input _grid _interpolation_mode _padding_mode = unsafePerformIO $ (ATen.cast4 ATen.Managed.grid_sampler_2d_ttll) _input _grid _interpolation_mode _padding_mode

-- grid_sampler_3d :: Tensor device dtype shape -> Tensor device dtype shape -> Int -> Int -> Tensor device dtype shape
-- grid_sampler_3d _input _grid _interpolation_mode _padding_mode = unsafePerformIO $ (ATen.cast4 ATen.Managed.grid_sampler_3d_ttll) _input _grid _interpolation_mode _padding_mode

-- hinge_embedding_loss :: Tensor device dtype shape -> Tensor device dtype shape -> Double -> Int -> Tensor device dtype shape
-- hinge_embedding_loss _input _target _margin _reduction = unsafePerformIO $ (ATen.cast4 ATen.Managed.hinge_embedding_loss_ttdl) _input _target _margin _reduction

-- ger :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- ger _input _vec2 = unsafePerformIO $ (ATen.cast2 ATen.Managed.ger_tt) _input _vec2

-- group_norm :: Tensor device dtype shape -> Int -> Tensor device dtype shape -> Tensor device dtype shape -> Double -> Bool -> Tensor device dtype shape
-- group_norm _input _num_groups _weight _bias _eps _cudnn_enabled = unsafePerformIO $ (ATen.cast6 ATen.Managed.group_norm_tlttdb) _input _num_groups _weight _bias _eps _cudnn_enabled

-- fft :: Tensor device dtype shape -> Int -> Bool -> Tensor device dtype shape
-- fft _input _signal_ndim _normalized = unsafePerformIO $ (ATen.cast3 ATen.Managed.fft_tlb) _input _signal_ndim _normalized

-- ifft :: Tensor device dtype shape -> Int -> Bool -> Tensor device dtype shape
-- ifft _input _signal_ndim _normalized = unsafePerformIO $ (ATen.cast3 ATen.Managed.ifft_tlb) _input _signal_ndim _normalized

-- rfft :: Tensor device dtype shape -> Int -> Bool -> Bool -> Tensor device dtype shape
-- rfft _input _signal_ndim _normalized _onesided = unsafePerformIO $ (ATen.cast4 ATen.Managed.rfft_tlbb) _input _signal_ndim _normalized _onesided

-- irfft :: Tensor device dtype shape -> Int -> Bool -> Bool -> [Int] -> Tensor device dtype shape
-- irfft _input _signal_ndim _normalized _onesided _signal_sizes = unsafePerformIO $ (ATen.cast5 ATen.Managed.irfft_tlbbl) _input _signal_ndim _normalized _onesided _signal_sizes

-- index :: Tensor device dtype shape -> [Tensor device dtype shape] -> Tensor device dtype shape
-- index _input _indices = unsafePerformIO $ (ATen.cast2 ATen.Managed.index_tl) _input _indices

-- index_copy :: Tensor device dtype shape -> Int -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- index_copy _input _dim _index _source = unsafePerformIO $ (ATen.cast4 ATen.Managed.index_copy_tltt) _input _dim _index _source

-- index_put :: Tensor device dtype shape -> [Tensor device dtype shape] -> Tensor device dtype shape -> Bool -> Tensor device dtype shape
-- index_put _input _indices _values _accumulate = unsafePerformIO $ (ATen.cast4 ATen.Managed.index_put_tltb) _input _indices _values _accumulate

-- instance_norm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Double -> Double -> Bool -> Tensor device dtype shape
-- instance_norm _input _weight _bias _running_mean _running_var _use_input_stats _momentum _eps _cudnn_enabled = unsafePerformIO $ (ATen.cast9 ATen.Managed.instance_norm_tttttbddb) _input _weight _bias _running_mean _running_var _use_input_stats _momentum _eps _cudnn_enabled

-- | isclose
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ isclose 0.1 0.1 False (ones :: CPUTensor 'D.Float '[3,2]) (ones :: CPUTensor 'D.Float '[3,2])
-- (Bool,[3,2])
isclose ::
  forall shape dtype device.
  -- | relative tolerance
  Double ->
  -- | absolute tolerance
  Double ->
  -- | whether or not NaN equals NaN
  Bool ->
  -- | input tensor
  Tensor device dtype shape ->
  -- | other input tensor
  Tensor device dtype shape ->
  -- | output
  Tensor device 'D.Bool shape
isclose rtol atol equalNaN input other =
  unsafePerformIO $ ATen.cast5 ATen.Managed.isclose_ttddb input other rtol atol equalNaN

-- | is NaN
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ isNaN (ones :: CPUTensor 'D.Float '[3,2])
-- (Bool,[3,2])
isNaN ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device 'D.Bool shape
isNaN input = unsafePerformIO $ ATen.cast1 ATen.Managed.isnan_t input

-- | is distributed
isDistributed ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Bool
isDistributed input = unsafePerformIO $ ATen.cast1 ATen.Managed.is_distributed_t input

-- | is floating point
-- TODO: this can be decided statically
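--
-- An illustrative usage sketch (not run as a doctest):
--
-- -- >>> isFloatingPoint (ones :: CPUTensor 'D.Float '[3])
-- -- True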
isFloatingPoint ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Bool
isFloatingPoint input = unsafePerformIO $ ATen.cast1 ATen.Managed.is_floating_point_t input

-- | is complex
isComplex ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Bool
isComplex input = unsafePerformIO $ ATen.cast1 ATen.Managed.is_complex_t input

-- | is non-zero
-- this operation is only defined for tensors that contain exactly one element, e.g. tensors of shape '[] or '[1]
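--
-- An illustrative usage sketch, left commented out rather than run as a doctest:
--
-- -- >>> isNonZero (ones :: CPUTensor 'D.Float '[1])
-- -- True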
isNonZero ::
  forall shape dtype device.
  (Numel shape ~ 1) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Bool
isNonZero input = unsafePerformIO $ ATen.cast1 ATen.Managed.is_nonzero_t input

-- | is same size
-- TODO: this can be decided statically
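--
-- An illustrative usage sketch (not run as a doctest):
--
-- -- >>> isSameSize (ones :: CPUTensor 'D.Float '[3,2]) (ones :: CPUTensor 'D.Float '[3,2])
-- -- True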
isSameSize ::
  forall shape shape' dtype device.
  -- | input tensor
  Tensor device dtype shape ->
  -- | other input tensor
  Tensor device dtype shape' ->
  -- | output
  Bool
isSameSize input other =
  unsafePerformIO $ ATen.cast2 ATen.Managed.is_same_size_tt input other

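-- | is signed
--
-- An illustrative usage sketch, left commented out rather than run as a doctest:
--
-- -- >>> isSigned (ones :: CPUTensor 'D.Float '[3])
-- -- True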
isSigned ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Bool
isSigned input = unsafePerformIO $ ATen.cast1 ATen.Managed.is_signed_t input

-- kl_div :: Tensor device dtype shape -> Tensor device dtype shape -> Int -> Tensor device dtype shape
-- kl_div _input _target _reduction = unsafePerformIO $ (ATen.cast3 ATen.Managed.kl_div_ttl) _input _target _reduction

-- kthvalue :: Tensor device dtype shape -> Int -> Int -> Bool -> (Tensor device dtype shape,Tensor device dtype shape)
-- kthvalue _input _k _dim _keepdim = unsafePerformIO $ (ATen.cast4 ATen.Managed.kthvalue_tllb) _input _k _dim _keepdim

-- | layerNorm
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: figure out if and when CUDNN works here, tie it also to the `device`
--
-- >>> t = layerNorm @'[1, 2] @'[2, 1, 2] @'D.Float @'( 'D.CPU, 0) ones ones 0.01 ones
-- >>> :type t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 1, 2]
-- >>> dtype &&& shape $ t
-- (Float,[2,1,2])
layerNorm ::
  forall normalizedShape shape dtype device.
  ( KnownShape normalizedShape,
    IsSuffixOf normalizedShape shape
  ) =>
  -- | weight
  Tensor device dtype normalizedShape ->
  -- | bias
  Tensor device dtype normalizedShape ->
  -- | eps
  Double ->
  -- | input tensor
  Tensor device dtype shape ->
  -- | output tensor
  Tensor device dtype shape
layerNorm weight bias eps input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.layer_norm_tlttdb
      input
      (shapeVal @normalizedShape)
      weight
      bias
      eps
      ( cudnnIsAcceptable weight
          && cudnnIsAcceptable bias
          && cudnnIsAcceptable input
      )

-- native_layer_norm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Int -> Int -> Double -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- native_layer_norm _input _weight _bias _M _N _eps = unsafePerformIO $ (ATen.cast6 ATen.Managed.native_layer_norm_tttlld) _input _weight _bias _M _N _eps

-- | linear
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#linear
--
-- >>> w = fromJust [[-0.5, -2,  0.5], [1.5, -0.5, 0.5]] :: CPUTensor 'D.Float '[2, 3]
-- >>> b = fromJust [0, 0.5] :: CPUTensor 'D.Float '[2]
-- >>> t = fromJust [[-2, 0.5, 1], [0.5, 0, 0], [0, 1, 0], [0, 0, 0], [1, -1, 0]] :: CPUTensor 'D.Float '[5, 3]
-- >>> t' = linear w b t
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([5,2],[[0.5,-2.25],[-0.25,1.25],[-2.0,0.0],[0.0,0.5],[1.5,2.5]]))
linear ::
  forall batchSize inputFeatures outputFeatures dtype device.
  Tensor device dtype '[outputFeatures, inputFeatures] ->
  Tensor device dtype '[outputFeatures] ->
  Tensor device dtype '[batchSize, inputFeatures] ->
  Tensor device dtype '[batchSize, outputFeatures]
linear weight bias input = unsafePerformIO $ ATen.cast3 ATen.Managed.linear_ttt input weight bias

-- | linear'
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: can we use the ATen linear function or not here?
-- https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#linear
--
-- >>> w = fromJust [[-0.5, -2,  0.5], [1.5, -0.5, 0.5]] :: CPUTensor 'D.Float '[2, 3]
-- >>> b = fromJust [0, 0.5] :: CPUTensor 'D.Float '[2]
-- >>> t = fromJust [[-2, 0.5, 1], [0.5, 0, 0], [0, 1, 0], [0, 0, 0], [1, -1, 0]] :: CPUTensor 'D.Float '[5, 3]
-- >>> t' = linear' w b t
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- (Float,([5,2],[[0.5,-2.25],[-0.25,1.25],[-2.0,0.0],[0.0,0.5],[1.5,2.5]]))
-- >>> t = fromJust [[[[-2, 0.5, 1], [0.5, 0, 0], [0, 1, 0], [0, 0, 0], [1, -1, 0]], [[-2, 0.5, 1], [0.5, 0, 0], [0, 1, 0], [0, 0, 0], [1, -1, 0]]]] :: CPUTensor 'D.Float '[1, 2, 5, 3]
-- >>> t' = linear' w b t
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 2, 5, 2]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[[[Float]]]]) $ t'
-- (Float,([1,2,5,2],[[[[0.5,-2.25],[-0.25,1.25],[-2.0,0.0],[0.0,0.5],[1.5,2.5]],[[0.5,-2.25],[-0.25,1.25],[-2.0,0.0],[0.0,0.5],[1.5,2.5]]]]))
linear' ::
  forall (inputFeatures :: Nat) (outputFeatures :: Nat) (shape :: [Nat]) (shape' :: [Nat]) dtype device (shape'' :: [Nat]).
  ( shape'' ~ MatMul shape '[inputFeatures, outputFeatures],
    shape' ~ Broadcast shape'' shape''
  ) =>
  -- | weight
  Tensor device dtype '[outputFeatures, inputFeatures] ->
  -- | bias
  Tensor device dtype '[outputFeatures] ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  -- linear' weight bias input = Torch.Static.add (matmul input $ transpose @0 @1 weight) bias
  Tensor device dtype shape'
linear' weight bias input = unsafePerformIO $ ATen.cast3 ATen.Managed.linear_ttt input weight bias

-- | mkldnnLinear
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: mkldnnLinear does not return a usable tensor value and is hence broken
-- TODO: figure out `device` for this
--
-- >>> w = fromJust [[-0.5, -2,  0.5], [1.5, -0.5, 0.5]] :: CPUTensor 'D.Float '[2, 3]
-- >>> b = fromJust [0, 0.5] :: CPUTensor 'D.Float '[2]
-- >>> t = fromJust [[-2, 0.5, 1], [0.5, 0, 0], [0, 1, 0], [0, 0, 0], [1, -1, 0]] :: CPUTensor 'D.Float '[5, 3]
--
-- -- >>> t' = mkldnnLinear (toMKLDNN w) (toMKLDNN b) (toMKLDNN t)
-- -- >>> :type t'
-- -- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[5, 2]
-- -- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[Float]]) $ t'
-- -- (Float,([5,2],[[0.5,-2.25],[-0.25,1.25],[-2.0,0.0],[0.0,0.5],[1.5,2.5]]))
mkldnnLinear ::
  forall batchSize inputFeatures outputFeatures dtype device.
  -- | weight
  Tensor device dtype '[outputFeatures, inputFeatures] ->
  -- | bias
  Tensor device dtype '[outputFeatures] ->
  -- | input
  Tensor device dtype '[batchSize, inputFeatures] ->
  -- | output
  Tensor device dtype '[batchSize, outputFeatures]
mkldnnLinear weight bias input = unsafePerformIO $ ATen.cast3 ATen.Managed.mkldnn_linear_ttt input weight bias

-- fbgemm_linear_int8_weight :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Float -> Tensor device dtype shape -> Tensor device dtype shape
-- fbgemm_linear_int8_weight _input _weight _packed _col_offsets _weight_scale _weight_zero_point _bias = unsafePerformIO $ (ATen.cast7 ATen.Managed.fbgemm_linear_int8_weight_ttttsst) _input _weight _packed _col_offsets _weight_scale _weight_zero_point _bias

-- fbgemm_linear_quantize_weight :: Tensor device dtype shape -> (Tensor device dtype shape,Tensor device dtype shape,Double,Int)
-- fbgemm_linear_quantize_weight _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.fbgemm_linear_quantize_weight_t) _input

-- fbgemm_pack_gemm_matrix_fp16 :: Tensor device dtype shape -> Tensor device dtype shape
-- fbgemm_pack_gemm_matrix_fp16 _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.fbgemm_pack_gemm_matrix_fp16_t) _input

-- fbgemm_linear_fp16_weight :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- fbgemm_linear_fp16_weight _input _packed_weight _bias = unsafePerformIO $ (ATen.cast3 ATen.Managed.fbgemm_linear_fp16_weight_ttt) _input _packed_weight _bias

-- fbgemm_pack_quantized_matrix :: Tensor device dtype shape -> Int -> Int -> Tensor device dtype shape
-- fbgemm_pack_quantized_matrix _input _K _N = unsafePerformIO $ (ATen.cast3 ATen.Managed.fbgemm_pack_quantized_matrix_tll) _input _K _N

-- fbgemm_is_cpu_supported :: Bool
-- fbgemm_is_cpu_supported  = unsafePerformIO $ (cast0 ATen.Managed.fbgemm_is_cpu_supported)

-- | log
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: will log throw for negative numbers or just generate NaNs? should we return a Maybe?
--
-- >>> dtype &&& shape $ log (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
log ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log input = unsafePerformIO $ ATen.cast1 ATen.Managed.log_t input

-- | logDet
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: will logDet throw? and if so, should we return a Maybe?
--
-- >>> dtype &&& shape $ logDet (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
-- >>> dtype &&& shape $ logDet (ones :: CPUTensor 'D.Float '[3,2,2])
-- (Float,[3])
logDet ::
  forall shape' shape dtype device.
  (shape' ~ Det shape) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
logDet input = unsafePerformIO $ ATen.cast1 ATen.Managed.logdet_t input

-- | logarithm of the sum of the exponentials
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- See https://pytorch.org/docs/stable/torch.html#torch.logsumexp.
--
-- >>> t = fromJust [[5, 1], [3, 2], [4, 1], [2, 7]] :: CPUTensor 'D.Float '[4, 2]
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Float]) $ logSumExp @1 @DropDim t
-- (Float,([4],[5.01815,3.3132617,4.0485873,7.0067153]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Float]]) $ logSumExp @1 @KeepDim t
-- (Float,([4,1],[[5.01815],[3.3132617],[4.0485873],[7.0067153]]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Float]) $ logSumExp @0 @DropDim t
-- (Float,([2],[5.44019,7.0116277]))
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Float]]) $ logSumExp @0 @KeepDim t
-- (Float,([1,2],[[5.44019,7.0116277]]))
logSumExp ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    Reifies dtype D.DType,
    DTypeIsFloatingPoint device dtype,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
logSumExp input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.logsumexp_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- margin_ranking_loss :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Double -> Int -> Tensor device dtype shape
-- margin_ranking_loss _input1 _input2 _target _margin _reduction = unsafePerformIO $ (ATen.cast5 ATen.Managed.margin_ranking_loss_tttdl) _input1 _input2 _target _margin _reduction

-- | matrixPower
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: figure out input shape restrictions, should be matrix or a batched matrix
-- TODO: figure out restrictions on the power, can it be zero or negative?
--
-- >>> dtype &&& shape $ matrixPower 2 (ones :: CPUTensor 'D.Float '[3,4,4])
-- (Float,[3,4,4])
matrixPower ::
  forall shape' shape dtype device.
  (shape' ~ Square shape) =>
  -- | power
  Int ->
  -- | input matrix
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
matrixPower n input = unsafePerformIO $ ATen.cast2 ATen.Managed.matrix_power_tl input n

-- | maxValues
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ maxValues @0 @KeepDim t
-- (Float,[1,4,5])
-- >>> dtype &&& shape $ maxValues @0 @DropDim t
-- (Float,[4,5])
-- >>> dtype &&& shape $ maxValues @1 @KeepDim t
-- (Float,[3,1,5])
-- >>> dtype &&& shape $ maxValues @1 @DropDim t
-- (Float,[3,5])
maxValues ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
maxValues input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.max_values_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- | minValues
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ minValues @0 @KeepDim t
-- (Float,[1,4,5])
-- >>> dtype &&& shape $ minValues @0 @DropDim t
-- (Float,[4,5])
-- >>> dtype &&& shape $ minValues @1 @KeepDim t
-- (Float,[3,1,5])
-- >>> dtype &&& shape $ minValues @1 @DropDim t
-- (Float,[3,5])
minValues ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
minValues input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.min_values_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)

-- | maxPool1d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = maxPool1d @1 @1 @0 (ones :: CPUTensor 'D.Float '[1,3,4])
-- >>> shape t
-- [1,3,4]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 4]
maxPool1d ::
  forall kernelSize stride padding channelSize inputSize batchSize outputSize dtype device.
  ( All
      KnownNat
      '[ kernelSize,
         stride,
         padding,
         channelSize,
         inputSize,
         batchSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize]
maxPool1d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.max_pool1d_tllllb
      input
      (natValI @kernelSize)
      (natValI @stride)
      (natValI @padding)
      (1 :: Int)
      False

-- max_pool1d_with_indices :: Tensor device dtype shape -> Int -> Int -> Int -> Int -> Bool -> (Tensor device dtype shape,Tensor device dtype shape)
-- max_pool1d_with_indices _input _kernel_size _stride _padding _dilation _ceil_mode = unsafePerformIO $ (ATen.cast6 ATen.Managed.max_pool1d_with_indices_tllllb) _input _kernel_size _stride _padding _dilation _ceil_mode

-- | maxPool2d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = maxPool2d @'(1,1) @'(1,1) @'(0,0) (ones :: CPUTensor 'D.Float '[1,3,4,5]) -- Skip warning
-- ...
-- >>> shape t
-- [1,3,4,5]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 4, 5]
maxPool2d ::
  forall kernelSize stride padding channelSize inputSize0 inputSize1 batchSize outputSize0 outputSize1 dtype device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst kernelSize,
         Torch.Typed.Auxiliary.Snd kernelSize,
         Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         channelSize,
         inputSize0,
         inputSize1,
         batchSize
       ],
    ConvSideCheck inputSize0 (Torch.Typed.Auxiliary.Fst kernelSize) (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 (Torch.Typed.Auxiliary.Snd kernelSize) (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1]
maxPool2d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.max_pool2d_tllllb
      input
      ([natValI @(Torch.Typed.Auxiliary.Fst kernelSize), natValI @(Torch.Typed.Auxiliary.Snd kernelSize)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst stride), natValI @(Torch.Typed.Auxiliary.Snd stride)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst padding), natValI @(Torch.Typed.Auxiliary.Snd padding)] :: [Int])
      ([1, 1] :: [Int])
      False

-- | mkldnnMaxPool2d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: does this function work, that is, does it return values without throwing? when does it work?
-- TODO: this should probably be only callable if the device is MKLDNN?
-- -- >>> t = mkldnnMaxPool2d @'(1,1) @'(1,1) @'(0,0) (toMKLDNN (ones :: CPUTensor 'D.Float '[1,3,4,5]))
-- -- >>> shape t
-- -- [1,3,4,5]
-- -- >>> :t t
-- -- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 4, 5]
mkldnnMaxPool2d ::
  forall kernelSize stride padding channelSize inputSize0 inputSize1 batchSize outputSize0 outputSize1 dtype device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst kernelSize,
         Torch.Typed.Auxiliary.Snd kernelSize,
         Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         channelSize,
         inputSize0,
         inputSize1,
         batchSize
       ],
    ConvSideCheck inputSize0 (Torch.Typed.Auxiliary.Fst kernelSize) (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 (Torch.Typed.Auxiliary.Snd kernelSize) (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1]
mkldnnMaxPool2d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.mkldnn_max_pool2d_tllllb
      input
      ([natValI @(Torch.Typed.Auxiliary.Fst kernelSize), natValI @(Torch.Typed.Auxiliary.Snd kernelSize)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst stride), natValI @(Torch.Typed.Auxiliary.Snd stride)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst padding), natValI @(Torch.Typed.Auxiliary.Snd padding)] :: [Int])
      ([1, 1] :: [Int])
      False

-- | quantizedMaxPool2d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: what are quantized functions and when are they available?
-- -- >>> t = quantizedMaxPool2d @'(1,1) @'(1,1) @'(0,0) (ones :: CPUTensor 'D.Float '[1,3,4,5])
-- -- >>> shape t
-- -- [1,3,4,5]
-- -- >>> :t t
-- -- t :: Tensor 'D.Float '[1, 3, 4, 5]
quantizedMaxPool2d ::
  forall kernelSize stride padding channelSize inputSize0 inputSize1 batchSize outputSize0 outputSize1 dtype device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst kernelSize,
         Torch.Typed.Auxiliary.Snd kernelSize,
         Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         channelSize,
         inputSize0,
         inputSize1,
         batchSize
       ],
    ConvSideCheck inputSize0 (Torch.Typed.Auxiliary.Fst kernelSize) (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 (Torch.Typed.Auxiliary.Snd kernelSize) (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1]
quantizedMaxPool2d input =
  unsafePerformIO $
    ATen.cast5
      ATen.Managed.quantized_max_pool2d_tllll
      input
      ([natValI @(Torch.Typed.Auxiliary.Fst kernelSize), natValI @(Torch.Typed.Auxiliary.Snd kernelSize)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst stride), natValI @(Torch.Typed.Auxiliary.Snd stride)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst padding), natValI @(Torch.Typed.Auxiliary.Snd padding)] :: [Int])
      ([1, 1] :: [Int])

-- | maxPool3d
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = maxPool3d @'(1,1,1) @'(1,1,1) @'(0,0,0) (ones :: CPUTensor 'D.Float '[1,3,4,5,6])
-- >>> shape t
-- [1,3,4,5,6]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 4, 5, 6]
maxPool3d ::
  forall
    kernelSize
    stride
    padding
    channelSize
    inputSize0
    inputSize1
    inputSize2
    batchSize
    outputSize0
    outputSize1
    outputSize2
    dtype
    device.
  ( All
      KnownNat
      '[ Fst3 kernelSize,
         Snd3 kernelSize,
         Trd3 kernelSize,
         Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         channelSize,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 (Fst3 kernelSize) (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 (Snd3 kernelSize) (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 (Trd3 kernelSize) (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1, inputSize2] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1, outputSize2]
maxPool3d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.max_pool3d_tllllb
      input
      ( [ natValI @(Fst3 kernelSize),
          natValI @(Snd3 kernelSize),
          natValI @(Trd3 kernelSize)
        ] ::
          [Int]
      )
      ([natValI @(Fst3 stride), natValI @(Snd3 stride), natValI @(Trd3 stride)] :: [Int])
      ([natValI @(Fst3 padding), natValI @(Snd3 padding), natValI @(Trd3 padding)] :: [Int])
      ([1, 1, 1] :: [Int])
      False

-- | maskedFill
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = ones :: CPUTensor 'D.Float '[2, 1, 3]
-- >>> m = fromJust [[False], [True], [False]] :: CPUTensor 'D.Bool '[3, 1]
-- >>> t' = maskedFill @Float m 0.5 t
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 3, 3]
-- >>> dtype &&& shape &&& (\u -> D.asValue (toDynamic u) :: [[[Float]]]) $ t'
-- (Float,([2,3,3],[[[1.0,1.0,1.0],[0.5,0.5,0.5],[1.0,1.0,1.0]],[[1.0,1.0,1.0],[0.5,0.5,0.5],[1.0,1.0,1.0]]]))
maskedFill ::
  forall a shape shape' shape'' dtype device.
  (D.Scalar a, shape'' ~ Broadcast shape shape') =>
  -- | mask
  Tensor device 'D.Bool shape' ->
  -- | fill value
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape''
maskedFill mask value input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.masked_fill_tts input mask value

-- mkldnn_convolution :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> [Int] -> [Int] -> Int -> Tensor device dtype shape
-- mkldnn_convolution _input _weight _bias _padding _stride _dilation _groups = unsafePerformIO $ (ATen.cast7 ATen.Managed.mkldnn_convolution_tttllll) _input _weight _bias _padding _stride _dilation _groups

-- miopen_batch_norm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Double -> Double -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- miopen_batch_norm _input _weight _bias _running_mean _running_var _training _exponential_average_factor _epsilon = unsafePerformIO $ (ATen.cast8 ATen.Managed.miopen_batch_norm_tttttbdd) _input _weight _bias _running_mean _running_var _training _exponential_average_factor _epsilon

-- miopen_convolution :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> [Int] -> [Int] -> Int -> Bool -> Bool -> Tensor device dtype shape
-- miopen_convolution _input _weight _bias _padding _stride _dilation _groups _benchmark _deterministic = unsafePerformIO $ (ATen.cast9 ATen.Managed.miopen_convolution_tttllllbb) _input _weight _bias _padding _stride _dilation _groups _benchmark _deterministic

-- miopen_convolution_transpose :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> [Int] -> [Int] -> [Int] -> Int -> Bool -> Bool -> Tensor device dtype shape
-- miopen_convolution_transpose _input _weight _bias _padding _output_padding _stride _dilation _groups _benchmark _deterministic = unsafePerformIO $ (ATen.cast10 ATen.Managed.miopen_convolution_transpose_tttlllllbb) _input _weight _bias _padding _output_padding _stride _dilation _groups _benchmark _deterministic

-- miopen_depthwise_convolution :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> [Int] -> [Int] -> Int -> Bool -> Bool -> Tensor device dtype shape
-- miopen_depthwise_convolution _input _weight _bias _padding _stride _dilation _groups _benchmark _deterministic = unsafePerformIO $ (ATen.cast9 ATen.Managed.miopen_depthwise_convolution_tttllllbb) _input _weight _bias _padding _stride _dilation _groups _benchmark _deterministic

-- miopen_rnn :: Tensor device dtype shape -> [Tensor device dtype shape] -> Int -> Tensor device dtype shape -> Tensor device dtype shape -> Int -> Int -> Int -> Bool -> Double -> Bool -> Bool -> [Int] -> Tensor device dtype shape -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- miopen_rnn _input _weight _weight_stride0 _hx _cx _mode _hidden_size _num_layers _batch_first _dropout _train _bidirectional _batch_sizes _dropout_state = unsafePerformIO $ (ATen.cast14 ATen.Managed.miopen_rnn_tllttlllbdbblt) _input _weight _weight_stride0 _hx _cx _mode _hidden_size _num_layers _batch_first _dropout _train _bidirectional _batch_sizes _dropout_state

-- | matrix-matrix multiplication
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ mm (ones :: CPUTensor 'D.Float '[3,2]) (zeros :: CPUTensor 'D.Float '[2,4])
-- (Float,[3,4])
mm ::
  forall n k m dtype device.
  -- | first input matrix
  Tensor device dtype '[n, k] ->
  -- | second input matrix
  Tensor device dtype '[k, m] ->
  -- | output matrix
  Tensor device dtype '[n, m]
mm a b = unsafePerformIO $ ATen.cast2 ATen.Managed.mm_tt a b

-- | matrix-vector multiplication
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ mv (ones :: CPUTensor 'D.Float '[3,2]) (zeros :: CPUTensor 'D.Float '[2])
-- (Float,[3])
mv ::
  forall n m dtype device.
  -- | input matrix
  Tensor device dtype '[n, m] ->
  -- | input vector
  Tensor device dtype '[m] ->
  -- | output vector
  Tensor device dtype '[n]
mv input vec = unsafePerformIO $ ATen.cast2 ATen.Managed.mv_tt input vec

-- mvlgamma :: Tensor device dtype shape -> Int -> Tensor device dtype shape
-- mvlgamma _input _p = unsafePerformIO $ (ATen.cast2 ATen.Managed.mvlgamma_tl) _input _p

type family
  NarrowCheck
    (mbCurrent :: Maybe Nat)
    (mbUpdated :: Maybe [Nat])
    (shape :: [Nat])
    (dim :: Nat)
    (start :: Nat)
    (length :: Nat) ::
    [Nat]
  where
  NarrowCheck Nothing _ sh d _ _ = DimOutOfBound sh d
  NarrowCheck (Just c) Nothing sh d s l = DimOutOfBound sh d
  NarrowCheck _ (Just r) _ _ _ _ = r

type family Narrow' (dim :: Nat) (shape :: [Nat]) (current :: Maybe Nat) (start :: Nat) (length :: Nat) :: Maybe [Nat] where
  Narrow' d sh (Just c) s l =
    If
      ((s + l) <=? c)
      (ReplaceDim d sh l)
      ( TypeError
          ( Text "The end of the requested narrow segment "
              :<>: ShowType (s + l)
              :<>: Text " would be larger than current size "
              :<>: ShowType c
              :<>: Text " at dimension "
              :<>: ShowType d
          )
      )
  Narrow' d sh Nothing s l =
    TypeError
      ( Text "Requested narrow dimension "
          :<>: ShowType d
          :<>: Text " doesnt exist in "
          :<>: ShowType sh
      )

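-- | Compute the shape obtained by narrowing @shape@ to a slice of @length@
-- elements along dimension @dim@, starting at index @start@.
--
-- An illustrative reduction (not checked by doctest): @Narrow '[3,3,3] 0 0 2@
-- should reduce to @'[2, 3, 3]@, while an out-of-range dimension or an
-- over-long slice is rejected with a custom type error from @NarrowCheck@
-- and @Narrow'@.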
type family Narrow (shape :: [Nat]) (dim :: Nat) (start :: Nat) (length :: Nat) :: [Nat] where
  Narrow shape dim start length =
    NarrowCheck (ExtractDim dim shape) (Narrow' dim shape (ExtractDim dim shape) start length) shape dim start length

-- | "Narrow" a tensor by returning a tensor that is a slice from 'start' of length 'length' along 'dim'
--
-- >>> dtype &&& shape $ narrow @0 @0 @2 (ones :: CPUTensor 'D.Float '[3,3,3])
-- (Float,[2,3,3])
-- >>> dtype &&& shape $ narrow @1 @1 @2 (ones :: CPUTensor 'D.Half '[3,3,3])
-- (Half,[3,2,3])
-- >>> dtype &&& shape $ narrow @1 @1 @2 (ones :: CPUTensor 'D.Bool '[3,3,3])
-- (Bool,[3,2,3])
narrow ::
  forall dim start length shape mbSize mbNewShape dtype device.
  ( All KnownNat '[dim, start, length],
    All KnownNat shape
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype (Narrow shape dim start length)
narrow _input = unsafePerformIO $ (ATen.cast4 ATen.Managed.narrow_tlll) _input _dim _start _length
  where
    _dim = natValI @dim
    _start = natValI @start
    _length = natValI @length

-- native_batch_norm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Double -> Double -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- native_batch_norm _input _weight _bias _running_mean _running_var _training _momentum _eps = unsafePerformIO $ (ATen.cast8 ATen.Managed.native_batch_norm_tttttbdd) _input _weight _bias _running_mean _running_var _training _momentum _eps

-- batch_norm_stats :: Tensor device dtype shape -> Double -> (Tensor device dtype shape,Tensor device dtype shape)
-- batch_norm_stats _input _eps = unsafePerformIO $ (ATen.cast2 ATen.Managed.batch_norm_stats_td) _input _eps

-- batch_norm_elemt :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Double -> Tensor device dtype shape
-- batch_norm_elemt _input _weight _bias _mean _invstd _eps = unsafePerformIO $ (ATen.cast6 ATen.Managed.batch_norm_elemt_tttttd) _input _weight _bias _mean _invstd _eps

-- batch_norm_gather_stats :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Double -> Double -> Int -> (Tensor device dtype shape,Tensor device dtype shape)
-- batch_norm_gather_stats _input _mean _invstd _running_mean _running_var _momentum _eps _count = unsafePerformIO $ (ATen.cast8 ATen.Managed.batch_norm_gather_stats_tttttddl) _input _mean _invstd _running_mean _running_var _momentum _eps _count

-- batch_norm_gather_stats_with_counts :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Double -> Double -> [Int] -> (Tensor device dtype shape,Tensor device dtype shape)
-- batch_norm_gather_stats_with_counts _input _mean _invstd _running_mean _running_var _momentum _eps _counts = unsafePerformIO $ (ATen.cast8 ATen.Managed.batch_norm_gather_stats_with_counts_tttttddl) _input _mean _invstd _running_mean _running_var _momentum _eps _counts

-- batch_norm_update_stats :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Double -> (Tensor device dtype shape,Tensor device dtype shape)
-- batch_norm_update_stats _input _running_mean _running_var _momentum = unsafePerformIO $ (ATen.cast4 ATen.Managed.batch_norm_update_stats_tttd) _input _running_mean _running_var _momentum

-- | onesLike
--
-- >>> dtype &&& shape $ onesLike (ones :: CPUTensor 'D.Float '[3,4,5])
-- (Float,[3,4,5])
onesLike ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
onesLike input = unsafePerformIO $ ATen.cast1 ATen.Managed.ones_like_t input

-- pairwise_distance :: Tensor device dtype shape -> Tensor device dtype shape -> Double -> Double -> Bool -> Tensor device dtype shape
-- pairwise_distance _x1 _x2 _p _eps _keepdim = unsafePerformIO $ (ATen.cast5 ATen.Managed.pairwise_distance_ttddb) _x1 _x2 _p _eps _keepdim

-- cdist :: Tensor device dtype shape -> Tensor device dtype shape -> Double -> Tensor device dtype shape
-- cdist _x1 _x2 _p = unsafePerformIO $ (ATen.cast3 ATen.Managed.cdist_ttd) _x1 _x2 _p

-- pdist :: Tensor device dtype shape -> Double -> Tensor device dtype shape
-- pdist _input _p = unsafePerformIO $ (ATen.cast2 ATen.Managed.pdist_td) _input _p

-- cosine_similarity :: Tensor device dtype shape -> Tensor device dtype shape -> Int -> Double -> Tensor device dtype shape
-- cosine_similarity _x1 _x2 _dim _eps = unsafePerformIO $ (ATen.cast4 ATen.Managed.cosine_similarity_ttld) _x1 _x2 _dim _eps

-- pixel_shuffle :: Tensor device dtype shape -> Int -> Tensor device dtype shape
-- pixel_shuffle _input _upscale_factor = unsafePerformIO $ (ATen.cast2 ATen.Managed.pixel_shuffle_tl) _input _upscale_factor

-- pin_memory :: Tensor device dtype shape -> Tensor device dtype shape
-- pin_memory _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.pin_memory_t) _input

-- pinverse :: Tensor device dtype shape -> Double -> Tensor device dtype shape
-- pinverse _input _rcond = unsafePerformIO $ (ATen.cast2 ATen.Managed.pinverse_td) _input _rcond

-- poisson_nll_loss :: Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Bool -> Double -> Int -> Tensor device dtype shape
-- poisson_nll_loss _input _target _log_input _full _eps _reduction = unsafePerformIO $ (ATen.cast6 ATen.Managed.poisson_nll_loss_ttbbdl) _input _target _log_input _full _eps _reduction

-- | randLike
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t <- randLike (ones :: CPUTensor 'D.Float '[3,4,5])
-- >>> dtype &&& shape $ t
-- (Float,[3,4,5])
randLike ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  IO (Tensor device dtype shape)
randLike = ATen.cast1 ATen.Managed.rand_like_t

-- | randnLike
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t <- randnLike (ones :: CPUTensor 'D.Float '[3,4,5])
-- >>> dtype &&& shape $ t
-- (Float,[3,4,5])
randnLike ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  IO (Tensor device dtype shape)
randnLike = ATen.cast1 ATen.Managed.randn_like_t

-- | reciprocal
--
-- >>> dtype &&& shape $ reciprocal (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
reciprocal ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
reciprocal _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.reciprocal_t) _input

-- | negate
-- TODO: probably not defined for `D.Bool` tensors
--
-- >>> dtype &&& shape $ neg (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
neg ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
neg input = unsafePerformIO $ ATen.cast1 ATen.Managed.neg_t input

-- | round
-- TODO: probably only defined for floating point tensors
--
-- >>> dtype &&& shape $ round (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
round ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
round input = unsafePerformIO $ ATen.cast1 ATen.Managed.round_t input

-- | prelu activation function
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ prelu (ones :: CPUTensor 'D.Float '[]) (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
prelu ::
  forall shape dtype device.
  -- | weight
  Tensor device dtype '[] ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
prelu weight input = unsafePerformIO $ ATen.cast2 ATen.Managed.prelu_tt input weight

type family GeluDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  GeluDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  GeluDTypeIsValid '( 'D.CUDA, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype,
      DTypeIsNotHalf '( 'D.CUDA, deviceIndex) dtype
    )
  GeluDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype
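
-- For example (illustrative): @GeluDTypeIsValid '( 'D.CPU, 0) 'D.Float@ holds,
-- whereas @GeluDTypeIsValid '( 'D.CPU, 0) 'D.Half@ is rejected, since half
-- precision is excluded on the CPU by 'DTypeIsNotHalf'.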

-- | gelu activation function
--
-- >>> dtype &&& shape $ round (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
gelu ::
  forall shape dtype device.
  (GeluDTypeIsValid device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
gelu input = unsafePerformIO $ ATen.cast1 ATen.Managed.gelu_t input

-- hardshrink :: Tensor device dtype shape -> Float -> Tensor device dtype shape
-- hardshrink _input _lambd = unsafePerformIO $ (ATen.cast2 ATen.Managed.hardshrink_ts) _input _lambd

-- | rsqrt
--
-- >>> dtype &&& shape $ rsqrt (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
rsqrt ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
rsqrt input = unsafePerformIO $ ATen.cast1 ATen.Managed.rsqrt_t input

-- | celu activation function
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ celu 3.0 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
celu ::
  forall shape dtype device.
  -- | alpha
  Float ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
celu alpha input = unsafePerformIO $ ATen.cast2 ATen.Managed.celu_ts input alpha

-- slice :: Tensor device dtype shape -> Int -> Int -> Int -> Int -> Tensor device dtype shape
-- slice _input _dim _start _end _step = unsafePerformIO $ (ATen.cast5 ATen.Managed.slice_tllll) _input _dim _start _end _step

-- slogdet :: Tensor device dtype shape -> (Tensor device dtype shape,Tensor device dtype shape)
-- slogdet _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.slogdet_t) _input

-- smm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- smm _input _mat2 = unsafePerformIO $ (ATen.cast2 ATen.Managed.smm_tt) _input _mat2

-- split :: Tensor device dtype shape -> Int -> Int -> [Tensor device dtype shape]
-- split _input _split_size _dim = unsafePerformIO $ (ATen.cast3 ATen.Managed.split_tll) _input _split_size _dim

-- split_with_sizes :: Tensor device dtype shape -> [Int] -> Int -> [Tensor device dtype shape]
-- split_with_sizes _input _split_sizes _dim = unsafePerformIO $ (ATen.cast3 ATen.Managed.split_with_sizes_tll) _input _split_sizes _dim

-- sspaddmm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Float -> Tensor device dtype shape
-- sspaddmm _input _mat1 _mat2 _beta _alpha = unsafePerformIO $ (ATen.cast5 ATen.Managed.sspaddmm_tttss) _input _mat1 _mat2 _beta _alpha

type family StackImpl (dim :: Nat) (tensors :: [a]) (count :: Nat) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  StackImpl dim '[] count = Nothing
  StackImpl dim (Tensor device dtype shape ': '[]) count = MaybeTriple (ComputeStackShape shape dim count) (Just dtype) (Just device)
  StackImpl dim (Tensor device dtype shape ': Tensor device dtype shape ': tensors) count = StackImpl dim (Tensor device dtype shape ': tensors) (count + 1)
  StackImpl _ _ _ = Nothing

type family MaybePair (a' :: Maybe a) (b' :: Maybe b) :: Maybe (a, b) where
  MaybePair Nothing _ = Nothing
  MaybePair _ Nothing = Nothing
  MaybePair (Just a') (Just b') = Just '(a', b')

type family MaybeTriple (a' :: Maybe a) (b' :: Maybe b) (c' :: Maybe c) :: Maybe (a, b, c) where
  MaybeTriple Nothing _ _ = Nothing
  MaybeTriple _ Nothing _ = Nothing
  MaybeTriple _ _ Nothing = Nothing
  MaybeTriple (Just a') (Just b') (Just c') = Just '(a', b', c')

type family ComputeStackShape (shape :: [Nat]) (dim :: Nat) (count :: Nat) :: Maybe [Nat] where
  ComputeStackShape _ _ 0 = Nothing
  ComputeStackShape xs 0 count = Just (count ': xs)
  ComputeStackShape (x ': xs) dim count = AppendToMaybe x (ComputeStackShape xs (dim - 1) count)
  ComputeStackShape '[] _ _ = Nothing
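
-- A worked reduction (illustrative): @ComputeStackShape '[2, 2] 1 3@ reduces to
-- @'Just '[2, 3, 2]@, that is, a new dimension of size 3 (the number of stacked
-- tensors) is inserted at position 1 of the original shape.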

type family StackCheck (res :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: ([Nat], D.DType, (D.DeviceType, Nat)) where
  StackCheck 'Nothing = TypeError (Text "Stacking impossible.")
  StackCheck ('Just '(shape, dtype, device)) = '(shape, dtype, device)

-- | Stack
--
-- >>> type Ty = Stack 0 '[Tensor '( 'D.CPU, 0) 'D.Float '[]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[1], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Stack 0 '[Tensor '( 'D.CPU, 0) 'D.Float '[2,2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[1, 2, 2], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Stack 1 '[Tensor '( 'D.CPU, 0) 'D.Float '[2,2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[2, 1, 2], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Stack 2 '[Tensor '( 'D.CPU, 0) 'D.Float '[2,2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[2, 2, 1], 'D.Float, '( 'D.CPU, 0))
-- >>> type Ty = Stack 2 '[Tensor '( 'D.CPU, 0) 'D.Float '[2,2], Tensor '( 'D.CPU, 0) 'D.Float '[2,2], Tensor '( 'D.CPU, 0) 'D.Float '[2,2]]
-- >>> :kind! Ty
-- Ty :: ([Natural], D.DType, (D.DeviceType, Natural))
-- = '( '[2, 2, 3], 'D.Float, '( 'D.CPU, 0))
type Stack dim tensors = StackCheck (StackImpl dim tensors 1)

-- | stack
--
-- >>> t = ones :: CPUTensor 'D.Float '[]
-- >>> t' = stack @0 (t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[1]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [Float]) $ t'
-- (Float,([1],[1.0]))
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> t' = stack @0 (t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 2, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t'
-- (Float,([1,2,2],[[[1.0,1.0],[1.0,1.0]]]))
-- >>> t' = stack @1 (t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 1, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t'
-- (Float,([2,1,2],[[[1.0,1.0]],[[1.0,1.0]]]))
-- >>> t' = stack @2 (t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2, 1]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t'
-- (Float,([2,2,1],[[[1.0],[1.0]],[[1.0],[1.0]]]))
-- >>> t' = stack @2 (t :. t :. t :. HNil)
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2, 3]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t'
-- (Float,([2,2,3],[[[1.0,1.0,1.0],[1.0,1.0,1.0]],[[1.0,1.0,1.0],[1.0,1.0,1.0]]]))
stack ::
  forall dim shape dtype device tensors.
  ( KnownNat dim,
    '(shape, dtype, device) ~ Stack dim tensors,
    ATen.Castable (HList tensors) [D.ATenTensor]
  ) =>
  -- | input list of tensors
  HList tensors ->
  -- | output
  Tensor device dtype shape
stack tensors = unsafePerformIO $ ATen.cast2 ATen.Managed.stack_ll tensors (natValI @dim :: Int)

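-- | vecStack
--
-- Stack a statically sized vector of tensors of identical shape along dimension @dim@,
-- inserting a new dimension of size @n@. This mirrors 'stack', but takes a
-- 'Vector' instead of an 'HList'.
--
-- A minimal usage sketch (illustrative, not checked by doctest; it assumes
-- @V.replicate@ from "Data.Vector.Sized"):
--
-- > t = ones :: CPUTensor 'D.Float '[2,2]
-- > t' = vecStack @0 (V.replicate t :: Vector 3 (CPUTensor 'D.Float '[2,2]))
-- > -- t' should have type Tensor '( 'D.CPU, 0) 'D.Float '[3, 2, 2]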
vecStack ::
  forall dim n shape dtype device.
  (KnownNat dim, KnownNat n) =>
  -- | Input list of tensors
  Vector n (Tensor device dtype shape) ->
  -- | Output list of tensors
  Tensor device dtype (Insert dim n shape)
vecStack tensors = unsafePerformIO $ ATen.cast2 ATen.Managed.stack_ll tensors (natValI @dim :: Int)

-- stft :: Tensor device dtype shape -> Int -> Int -> Int -> Tensor device dtype shape -> Bool -> Bool -> Tensor device dtype shape
-- stft _input _n_fft _hop_length _win_length _window _normalized _onesided = unsafePerformIO $ (ATen.cast7 ATen.Managed.stft_tllltbb) _input _n_fft _hop_length _win_length _window _normalized _onesided

-- stride :: Tensor device dtype shape -> Int -> Int
-- stride _input _dim = unsafePerformIO $ (ATen.cast2 ATen.Managed.stride_tl) _input _dim

-- | t
--
-- dtype &&& shape $ t ones :: CPUTensor 'D.Float '[3,2]
-- (Float,[3,2])
t ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
t _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.t_t) _input

-- | tan
--
-- >>> dtype &&& shape $ tan (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
tan ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
tan input = unsafePerformIO $ ATen.cast1 ATen.Managed.tan_t input

-- tensordot :: Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> [Int] -> Tensor device dtype shape
-- tensordot _input _other _dims_input _dims_other = unsafePerformIO $ (ATen.cast4 ATen.Managed.tensordot_ttll) _input _other _dims_input _dims_other

-- threshold :: Tensor device dtype shape -> Float -> Float -> Tensor device dtype shape
-- threshold _input _threshold _value = unsafePerformIO $ (ATen.cast3 ATen.Managed.threshold_tss) _input _threshold _value

-- one_hot :: Tensor device dtype shape -> Int -> Tensor device dtype shape
-- one_hot _input _num_classes = unsafePerformIO $ (ATen.cast2 ATen.Managed.one_hot_tl) _input _num_classes

-- flip :: Tensor device dtype shape -> [Int] -> Tensor device dtype shape
-- flip _input _dims = unsafePerformIO $ (ATen.cast2 ATen.Managed.flip_tl) _input _dims

-- roll :: Tensor device dtype shape -> Int -> Int -> Tensor device dtype shape
-- roll _input _shifts _dims = unsafePerformIO $ (ATen.cast3 ATen.Managed.roll_tll) _input _shifts _dims

-- rot90 :: Tensor device dtype shape -> Int -> [Int] -> Tensor device dtype shape
-- rot90 _input _k _dims = unsafePerformIO $ (ATen.cast3 ATen.Managed.rot90_tll) _input _k _dims

-- triplet_margin_loss :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Double -> Double -> Double -> Bool -> Int -> Tensor device dtype shape
-- triplet_margin_loss _anchor _positive _negative _margin _p _eps _swap _reduction = unsafePerformIO $ (ATen.cast8 ATen.Managed.triplet_margin_loss_tttdddbl) _anchor _positive _negative _margin _p _eps _swap _reduction

-- | trunc
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ trunc (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
trunc ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
trunc input = unsafePerformIO $ ATen.cast1 ATen.Managed.trunc_t input

-- unique_dim :: Tensor device dtype shape -> Int -> Bool -> Bool -> Bool -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- unique_dim _input _dim _sorted _return_inverse _return_counts = unsafePerformIO $ (ATen.cast5 ATen.Managed.unique_dim_tlbbb) _input _dim _sorted _return_inverse _return_counts

-- unique_consecutive :: Tensor device dtype shape -> Bool -> Bool -> Int -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- unique_consecutive _input _return_inverse _return_counts _dim = unsafePerformIO $ (ATen.cast4 ATen.Managed.unique_consecutive_tbbl) _input _return_inverse _return_counts _dim

-- unique_dim_consecutive :: Tensor device dtype shape -> Int -> Bool -> Bool -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- unique_dim_consecutive _input _dim _return_inverse _return_counts = unsafePerformIO $ (ATen.cast4 ATen.Managed.unique_dim_consecutive_tlbb) _input _dim _return_inverse _return_counts

-- | UnsqueezeImpl
--
-- >>> :kind! UnsqueezeImpl '[4] 0
-- UnsqueezeImpl '[4] 0 :: Maybe [Natural]
-- = 'Just '[1, 4]
-- >>> :kind! UnsqueezeImpl '[4] 1
-- UnsqueezeImpl '[4] 1 :: Maybe [Natural]
-- = 'Just '[4, 1]
-- >>> :kind! UnsqueezeImpl '[4] 2
-- UnsqueezeImpl '[4] 2 :: Maybe [Natural]
-- = 'Nothing
type family UnsqueezeImpl (shape :: [a]) (dim :: Nat) :: Maybe [a] where
  UnsqueezeImpl xs 0 = Just (1 ': xs)
  UnsqueezeImpl (x ': xs) dim = AppendToMaybe x (UnsqueezeImpl xs (dim - 1))
  UnsqueezeImpl '[] _ = Nothing

type family UnsqueezeCheck (shape :: [a]) (dim :: Nat) (result :: Maybe [a]) :: [a] where
  UnsqueezeCheck shape dim Nothing =
    TypeError
      ( Text "Cannot unsqueeze the tensor since the specified dimension "
          :<>: ShowType dim
          :<>: Text " is too large (the tensor is only "
          :<>: ShowType (ListLength shape)
          :<>: Text "D)"
      )
  UnsqueezeCheck _ _ (Just shape') = shape'

type Unsqueeze shape dim = UnsqueezeCheck shape dim (UnsqueezeImpl shape dim)
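
-- For example (illustrative): @Unsqueeze '[4] 1@ reduces to @'[4, 1]@, while
-- @Unsqueeze '[4] 2@ is rejected with the type error produced by 'UnsqueezeCheck'.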

-- | unsqueeze
--
-- >>> t = fromJust [1, 2, 3, 4] :: CPUTensor 'D.Int64 '[4]
-- >>> t' = unsqueeze @0 t
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Int64 '[1, 4]
-- >>> dtype &&& shape &&& (\u -> D.asValue (toDynamic u) :: [[Int]]) $ t'
-- (Int64,([1,4],[[1,2,3,4]]))
-- >>> t'' = unsqueeze @1 t
-- >>> :type t''
-- t'' :: Tensor '( 'D.CPU, 0) 'D.Int64 '[4, 1]
-- >>> dtype &&& shape &&& (\u -> D.asValue (toDynamic u) :: [[Int]]) $ t''
-- (Int64,([4,1],[[1],[2],[3],[4]]))
unsqueeze ::
  forall dim shape shape' dtype device.
  (KnownNat dim, shape' ~ Unsqueeze shape dim) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
unsqueeze input = unsafePerformIO $ ATen.cast2 ATen.Managed.unsqueeze_tl input (natValI @dim)

type family SqueezeAll (shape :: [Nat]) :: [Nat] where
  SqueezeAll '[] = '[]
  SqueezeAll (1 ': xs) = SqueezeAll xs
  SqueezeAll (x ': xs) = x ': SqueezeAll xs
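
-- For example (illustrative): @SqueezeAll '[2, 1, 2, 1, 2]@ reduces to @'[2, 2, 2]@,
-- i.e. all dimensions of size 1 are removed.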

-- | squeeze all dimensions
--
-- >>> dtype &&& shape $ squeezeAll (ones :: CPUTensor 'D.Float '[2,1,2,1,2])
-- (Float,[2,2,2])
-- >>> squeezeAll (ones :: CPUTensor 'D.Float '[2,1,2,1,2])
-- Tensor Float [2,2,2] [[[ 1.0000   ,  1.0000   ],
--                        [ 1.0000   ,  1.0000   ]],
--                       [[ 1.0000   ,  1.0000   ],
--                        [ 1.0000   ,  1.0000   ]]]
squeezeAll ::
  forall shape shape' dtype device.
  (shape' ~ SqueezeAll shape) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
squeezeAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.squeeze_t input

type family SqueezeDimImpl (shape :: [a]) (dim :: Nat) :: Maybe [a] where
  SqueezeDimImpl (1 ': xs) 0 = Just xs
  SqueezeDimImpl _ 0 = Nothing
  SqueezeDimImpl (x ': xs) dim = AppendToMaybe x (SqueezeDimImpl xs (dim - 1))
  SqueezeDimImpl _ _ = Nothing

type family SqueezeDimCheck (shape :: [a]) (dim :: Nat) (result :: Maybe [a]) :: [a] where
  SqueezeDimCheck shape dim Nothing = TypeError (Text "The tensor cannot be squeezed at the specified dimension " :<>: ShowType dim)
  SqueezeDimCheck _ _ ('Just shape') = shape'

-- | Calculate the output shape of a squeeze along a given dimension
--
-- >>> :kind! SqueezeDim '[2,1,2] 1
-- SqueezeDim '[2,1,2] 1 :: [Natural]
-- = '[2, 2]
type SqueezeDim shape dim = SqueezeDimCheck shape dim (SqueezeDimImpl shape dim)

-- | squeeze a particular dimension
--
-- >>> dtype &&& shape $ squeezeDim @1 (ones :: CPUTensor 'D.Float '[2,1,2,1,2])
-- (Float,[2,2,1,2])
-- >>> dtype &&& shape $ squeezeDim @3 (ones :: CPUTensor 'D.Float '[2,1,2,1,2])
-- (Float,[2,1,2,2])
squeezeDim ::
  forall dim shape shape' dtype device.
  (KnownNat dim, shape' ~ SqueezeDim shape dim) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
squeezeDim input = unsafePerformIO $ ATen.cast2 ATen.Managed.squeeze_tl input (natValI @dim)

-- where' :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- where' _condition _input _other = unsafePerformIO $ (ATen.cast3 ATen.Managed.where_ttt) _condition _input _other

-- where_ :: Tensor device dtype shape -> [Tensor device dtype shape]
-- where_ _condition = unsafePerformIO $ (ATen.cast1 ATen.Managed.where_t) _condition

-- norm_except_dim :: Tensor device dtype shape -> Int -> Int -> Tensor device dtype shape
-- norm_except_dim _v _pow _dim = unsafePerformIO $ (ATen.cast3 ATen.Managed.norm_except_dim_tll) _v _pow _dim

-- | zerosLike
--
-- >>> dtype &&& shape $ zerosLike (ones :: CPUTensor 'D.Float '[3,4,5])
-- (Float,[3,4,5])
zerosLike ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
zerosLike :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> Tensor device dtype shape
zerosLike Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.zeros_like_t Tensor device dtype shape
input

-- native_norm :: Tensor device dtype shape -> Float -> Tensor device dtype shape
-- native_norm _input _p = unsafePerformIO $ (ATen.cast2 ATen.Managed.native_norm_ts) _input _p

-- | clone
--
-- >>> t <- clone (ones :: CPUTensor 'D.Float '[3,2])
-- >>> dtype &&& shape $ t
-- (Float,[3,2])
clone ::
  forall shape dtype device.
  Tensor device dtype shape ->
  IO (Tensor device dtype shape)
clone :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> IO (Tensor device dtype shape)
clone Tensor device dtype shape
input = forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.clone_t Tensor device dtype shape
input

-- s_native_addmm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Float -> Tensor device dtype shape
-- s_native_addmm _input _mat1 _mat2 _beta _alpha = unsafePerformIO $ (ATen.cast5 ATen.Managed.s_native_addmm_tttss) _input _mat1 _mat2 _beta _alpha

-- | addmm
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: can we use D.Scalar here for beta and alpha?
--
-- >>> t = addmm 1 1 (ones :: CPUTensor 'D.Float '[3,2]) (zeros :: CPUTensor 'D.Float '[2,4]) (ones :: CPUTensor 'D.Float '[])
-- >>> dtype &&& shape $ t
-- (Float,[3,4])
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[3, 4]
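-- Following ATen's @addmm@, the result should be
-- @beta * input + alpha * (mat1 * mat2)@, where @mat1 * mat2@ is the matrix product
-- and @input@ is broadcast against its @'[n, m]@ shape (hence the
-- @Broadcast shape '[n, m]@ constraint).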
addmm ::
  forall shape' shape n k m dtype device.
  ( All KnownNat '[n, k, m],
    shape' ~ Broadcast shape '[n, m]
  ) =>
  -- | beta
  Float ->
  -- | alpha
  Float ->
  -- | first input matrix
  Tensor device dtype '[n, k] ->
  -- | second input matrix
  Tensor device dtype '[k, m] ->
  -- | input tensor
  Tensor device dtype shape ->
  -- | output tensor
  Tensor device dtype shape'
addmm :: forall (shape' :: [Nat]) (shape :: [Nat]) (n :: Nat) (k :: Nat)
       (m :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)).
(All KnownNat '[n, k, m], shape' ~ Broadcast shape '[n, m]) =>
Float
-> Float
-> Tensor device dtype '[n, k]
-> Tensor device dtype '[k, m]
-> Tensor device dtype shape
-> Tensor device dtype shape'
addmm Float
beta Float
alpha Tensor device dtype '[n, k]
mat1 Tensor device dtype '[k, m]
mat2 Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca x1 cx1 x2 cx2 x3 cx3 x4 cx4 y cy.
(Castable a ca, Castable x1 cx1, Castable x2 cx2, Castable x3 cx3,
 Castable x4 cx4, Castable y cy) =>
(ca -> cx1 -> cx2 -> cx3 -> cx4 -> IO cy)
-> a -> x1 -> x2 -> x3 -> x4 -> IO y
ATen.cast5 ForeignPtr Tensor
-> ForeignPtr Tensor
-> ForeignPtr Tensor
-> ForeignPtr Scalar
-> ForeignPtr Scalar
-> IO (ForeignPtr Tensor)
ATen.Managed.addmm_tttss Tensor device dtype shape
input Tensor device dtype '[n, k]
mat1 Tensor device dtype '[k, m]
mat2 Float
beta Float
alpha

-- hspmm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- hspmm _mat1 _mat2 = unsafePerformIO $ (ATen.cast2 ATen.Managed.hspmm_tt) _mat1 _mat2

-- | numel
-- TODO: since this is decidable at compile time, this should probably be calculated from the tensor type instead
numel ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Int
numel :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> Int
numel Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO Int64
ATen.Managed.tensor_numel Tensor device dtype shape
input

-- unbind :: Tensor device dtype shape -> Int -> [Tensor device dtype shape]
-- unbind _input _dim = unsafePerformIO $ (ATen.cast2 ATen.Managed.unbind_tl) _input _dim

-- mkldnn_reorder_conv2d_weight :: Tensor device dtype shape -> (Int,Int) -> (Int,Int) -> (Int,Int) -> Int -> Tensor device dtype shape
-- mkldnn_reorder_conv2d_weight _input _padding _stride _dilation _groups = unsafePerformIO $ (ATen.cast5 ATen.Managed.mkldnn_reorder_conv2d_weight_tllll) _input _padding _stride _dilation _groups

--quantize_linear :: Tensor device dtype shape -> Double -> Int -> DType -> Tensor device dtype shape
--quantize_linear _input _scale _zero_point _dtype = unsafePerformIO $ (ATen.cast4 ATen.Managed.quantize_linear_tdls) _input _scale _zero_point _dtype

--quantize_linear_per_channel :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> [Int] -> DType -> Tensor device dtype shape
--quantize_linear_per_channel _input _scales _zero_points _axis _dtype = unsafePerformIO $ (ATen.cast5 ATen.Managed.quantize_linear_per_channel_tttls) _input _scales _zero_points _axis _dtype

-- dequantize :: Tensor device dtype shape -> Tensor device dtype shape
-- dequantize _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.dequantize_t) _input

-- | qScale
-- TODO: are there any restrictions on the dtype?
qScale ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Double
qScale :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> Double
qScale Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO CDouble
ATen.Managed.q_scale_t Tensor device dtype shape
input

-- | qZeroPoint
-- TODO: are there any restrictions on the dtype?
qZeroPoint ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Int
qZeroPoint :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> Int
qZeroPoint Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO Int64
ATen.Managed.q_zero_point_t Tensor device dtype shape
input

-- int_repr :: Tensor device dtype shape -> Tensor device dtype shape
-- int_repr _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.int_repr_t) _input

-- fake_quantize_per_tensor_affine :: Tensor device dtype shape -> Double -> Int -> Int -> Int -> Tensor device dtype shape
-- fake_quantize_per_tensor_affine _input _scale _zero_point _quant_min _quant_max = unsafePerformIO $ (ATen.cast5 ATen.Managed.fake_quantize_per_tensor_affine_tdlll) _input _scale _zero_point _quant_min _quant_max

-- meshgrid :: [Tensor device dtype shape] -> [Tensor device dtype shape]
-- meshgrid _tensors = unsafePerformIO $ (ATen.cast1 ATen.Managed.meshgrid_l) _tensors

-- cartesian_prod :: [Tensor device dtype shape] -> Tensor device dtype shape
-- cartesian_prod _tensors = unsafePerformIO $ (ATen.cast1 ATen.Managed.cartesian_prod_l) _tensors

-- combinations :: Tensor device dtype shape -> Int -> Bool -> Tensor device dtype shape
-- combinations _input _r _with_replacement = unsafePerformIO $ (ATen.cast3 ATen.Managed.combinations_tlb) _input _r _with_replacement

-- | The directional specification of a recurrent function
data RNNDirectionality
  = -- | Forward and backward along the sequential axis using independent parameters for each direction.
    Bidirectional
  | -- | Forward along the sequential axis.
    Unidirectional
  deriving (Int -> RNNDirectionality -> ShowS
[RNNDirectionality] -> ShowS
RNNDirectionality -> String
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
showList :: [RNNDirectionality] -> ShowS
$cshowList :: [RNNDirectionality] -> ShowS
show :: RNNDirectionality -> String
$cshow :: RNNDirectionality -> String
showsPrec :: Int -> RNNDirectionality -> ShowS
$cshowsPrec :: Int -> RNNDirectionality -> ShowS
Show, forall x. Rep RNNDirectionality x -> RNNDirectionality
forall x. RNNDirectionality -> Rep RNNDirectionality x
forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
$cto :: forall x. Rep RNNDirectionality x -> RNNDirectionality
$cfrom :: forall x. RNNDirectionality -> Rep RNNDirectionality x
Generic) -- TODO:  We could also have BidirectionalTied weights.

type family NumberOfDirections (directionality :: RNNDirectionality) :: Nat where
  NumberOfDirections Bidirectional = 2
  NumberOfDirections Unidirectional = 1

class KnownRNNDirectionality (directionality :: RNNDirectionality) where
  rnnBidirectional :: Bool

instance KnownRNNDirectionality Bidirectional where
  rnnBidirectional :: Bool
rnnBidirectional = Bool
True

instance KnownRNNDirectionality Unidirectional where
  rnnBidirectional :: Bool
rnnBidirectional = Bool
False

-- | Specification for the sequential axis of a recurrent function.
data RNNShapeOrder
  = -- | Input is of shape (Batch, Sequence, Features)
    BatchFirst
  | -- | Input is of shape (Sequence, Batch, Features)
    SequenceFirst
  deriving (Int -> RNNShapeOrder -> ShowS
[RNNShapeOrder] -> ShowS
RNNShapeOrder -> String
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
showList :: [RNNShapeOrder] -> ShowS
$cshowList :: [RNNShapeOrder] -> ShowS
show :: RNNShapeOrder -> String
$cshow :: RNNShapeOrder -> String
showsPrec :: Int -> RNNShapeOrder -> ShowS
$cshowsPrec :: Int -> RNNShapeOrder -> ShowS
Show, forall x. Rep RNNShapeOrder x -> RNNShapeOrder
forall x. RNNShapeOrder -> Rep RNNShapeOrder x
forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
$cto :: forall x. Rep RNNShapeOrder x -> RNNShapeOrder
$cfrom :: forall x. RNNShapeOrder -> Rep RNNShapeOrder x
Generic)

class KnownRNNShapeOrder (shapeOrder :: RNNShapeOrder) where
  rnnBatchFirst :: Bool

instance KnownRNNShapeOrder BatchFirst where
  rnnBatchFirst :: Bool
rnnBatchFirst = Bool
True

instance KnownRNNShapeOrder SequenceFirst where
  rnnBatchFirst :: Bool
rnnBatchFirst = Bool
False

type family RNNShape (shapeOrder :: RNNShapeOrder) (seqLen :: Nat) (batchSize :: Nat) (featureSize :: Nat) :: [Nat] where
  RNNShape BatchFirst seqLen batchSize featureSize = '[batchSize, seqLen, featureSize]
  RNNShape SequenceFirst seqLen batchSize featureSize = '[seqLen, batchSize, featureSize]
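-- For example, with a sequence length of 10, a batch size of 8, and 32 features,
-- 'RNNShape' reduces as follows (illustrative, not a verified doctest):
--
-- > RNNShape 'BatchFirst    10 8 32 ~ '[8, 10, 32]
-- > RNNShape 'SequenceFirst 10 8 32 ~ '[10, 8, 32]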

type LSTMWIShape hiddenSize inputSize = '[4 * hiddenSize, inputSize]

type LSTMWHShape hiddenSize inputSize = '[4 * hiddenSize, hiddenSize]

type LSTMBIShape hiddenSize inputSize = '[4 * hiddenSize]

type LSTMBHShape hiddenSize inputSize = '[4 * hiddenSize]

type family LSTMRImpl (inputSize :: Nat) (hiddenSize :: Nat) (numLayers :: Nat) (directionality :: RNNDirectionality) :: [[Nat]] where
  LSTMRImpl inputSize hiddenSize 1 'Unidirectional =
    '[ LSTMWIShape hiddenSize inputSize,
       LSTMWHShape hiddenSize inputSize,
       LSTMBIShape hiddenSize inputSize,
       LSTMBHShape hiddenSize inputSize
     ]
  LSTMRImpl inputSize hiddenSize numLayers 'Unidirectional =
    LSTMRImpl inputSize hiddenSize (numLayers - 1) 'Unidirectional
      ++ '[ LSTMWIShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional),
            LSTMWHShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional),
            LSTMBIShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional),
            LSTMBHShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional)
          ]
  LSTMRImpl inputSize hiddenSize 1 'Bidirectional =
    '[ LSTMWIShape hiddenSize inputSize,
       LSTMWHShape hiddenSize inputSize,
       LSTMBIShape hiddenSize inputSize,
       LSTMBHShape hiddenSize inputSize,
       LSTMWIShape hiddenSize inputSize,
       LSTMWHShape hiddenSize inputSize,
       LSTMBIShape hiddenSize inputSize,
       LSTMBHShape hiddenSize inputSize
     ]
  LSTMRImpl inputSize hiddenSize numLayers 'Bidirectional =
    LSTMRImpl inputSize hiddenSize (numLayers - 1) 'Bidirectional
      ++ '[ LSTMWIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            LSTMWHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            LSTMBIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            LSTMBHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            LSTMWIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            LSTMWHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            LSTMBIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            LSTMBHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional)
          ]

type family LSTMR' (shapes :: [[Nat]]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) :: [a] where
  LSTMR' '[] dtype device = '[]
  LSTMR' (shape ': shapes) dtype device = Tensor device dtype shape ': LSTMR' shapes dtype device

type LSTMR inputSize hiddenSize numLayers directionality dtype device = LSTMR' (LSTMRImpl inputSize hiddenSize numLayers directionality) dtype device
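-- The per-layer parameter shapes follow the usual LSTM convention of stacking the
-- four gate matrices, so every weight and bias has @4 * hiddenSize@ rows. As an
-- illustration (assumed, not a verified doctest), a single-layer unidirectional
-- LSTM with @inputSize = 5@ and @hiddenSize = 7@ expects the parameter list
--
-- > LSTMR 5 7 1 'Unidirectional dtype device
-- >   ~ '[ Tensor device dtype '[28, 5]  -- input-to-hidden weights
-- >      , Tensor device dtype '[28, 7]  -- hidden-to-hidden weights
-- >      , Tensor device dtype '[28]     -- input-to-hidden bias
-- >      , Tensor device dtype '[28]     -- hidden-to-hidden bias
-- >      ]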

-- | lstm
-- The parameters of this ATen function are non-trivial to assemble by hand; see the
-- `Typed.NN.LSTM` module for doctests.
lstm ::
  forall
    shapeOrder
    directionality
    numLayers
    seqLen
    batchSize
    inputSize
    outputSize
    hiddenSize
    inputShape
    outputShape
    hxShape
    tensorParameters
    dtype
    device.
  ( KnownNat numLayers,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hxShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
    tensorParameters ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor]
  ) =>
  HList tensorParameters ->
  Double ->
  Bool ->
  (Tensor device dtype hxShape, Tensor device dtype hxShape) ->
  Tensor device dtype inputShape ->
  ( Tensor device dtype outputShape,
    Tensor device dtype hxShape,
    Tensor device dtype hxShape
  )
lstm :: forall {k} (shapeOrder :: RNNShapeOrder)
       (directionality :: RNNDirectionality) (numLayers :: Nat)
       (seqLen :: Nat) (batchSize :: Nat) (inputSize :: Nat)
       (outputSize :: Nat) (hiddenSize :: Nat) (inputShape :: [Nat])
       (outputShape :: [Nat]) (hxShape :: [Nat]) (tensorParameters :: [k])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat numLayers, KnownRNNShapeOrder shapeOrder,
 KnownRNNDirectionality directionality,
 outputSize ~ (hiddenSize * NumberOfDirections directionality),
 inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
 outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
 hxShape
 ~ '[numLayers * NumberOfDirections directionality, batchSize,
     hiddenSize],
 tensorParameters
 ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
 Castable (HList tensorParameters) [ForeignPtr Tensor]) =>
HList tensorParameters
-> Double
-> Bool
-> (Tensor device dtype hxShape, Tensor device dtype hxShape)
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hxShape,
    Tensor device dtype hxShape)
lstm HList tensorParameters
tensorParameters Double
dropoutProb Bool
dropoutOn (Tensor device dtype hxShape
cc, Tensor device dtype hxShape
hc) Tensor device dtype inputShape
input =
  forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$
    forall a ca x1 cx1 x2 cx2 x3 cx3 x4 cx4 x5 cx5 x6 cx6 x7 cx7 x8 cx8
       y cy.
(Castable a ca, Castable x1 cx1, Castable x2 cx2, Castable x3 cx3,
 Castable x4 cx4, Castable x5 cx5, Castable x6 cx6, Castable x7 cx7,
 Castable x8 cx8, Castable y cy) =>
(ca
 -> cx1 -> cx2 -> cx3 -> cx4 -> cx5 -> cx6 -> cx7 -> cx8 -> IO cy)
-> a -> x1 -> x2 -> x3 -> x4 -> x5 -> x6 -> x7 -> x8 -> IO y
ATen.cast9
      ForeignPtr Tensor
-> ForeignPtr TensorList
-> ForeignPtr TensorList
-> CBool
-> Int64
-> CDouble
-> CBool
-> CBool
-> CBool
-> IO (ForeignPtr (StdTuple '(Tensor, Tensor, Tensor)))
ATen.Managed.lstm_tllbldbbb
      Tensor device dtype inputShape
input
      [Tensor device dtype hxShape]
hx
      HList tensorParameters
tensorParameters
      Bool
hasBiases
      Int64
numLayers
      Double
dropoutProb
      Bool
dropoutOn
      (forall (directionality :: RNNDirectionality).
KnownRNNDirectionality directionality =>
Bool
rnnBidirectional @directionality)
      (forall (shapeOrder :: RNNShapeOrder).
KnownRNNShapeOrder shapeOrder =>
Bool
rnnBatchFirst @shapeOrder)
  where
    hasBiases :: Bool
hasBiases = Bool
True
    hx :: [Tensor device dtype hxShape]
    hx :: [Tensor device dtype hxShape]
hx = [Tensor device dtype hxShape
cc, Tensor device dtype hxShape
hc]
    numLayers :: I.Int64
    numLayers :: Int64
numLayers = forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall (n :: Nat). KnownNat n => Int
natValI @numLayers

-- | lstmCell
--
-- >>> dtype &&& shape $ fst $ lstmCell (ones :: CPUTensor 'D.Float '[12,2]) (ones :: CPUTensor 'D.Float '[12,3]) (ones :: CPUTensor 'D.Float '[12]) (ones :: CPUTensor 'D.Float '[12]) ((ones :: CPUTensor 'D.Float '[2,3]), (ones :: CPUTensor 'D.Float '[2,3])) (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,3])
lstmCell ::
  forall inputSize hiddenSize batchSize dtype device.
  Tensor device dtype '[4 * hiddenSize, inputSize] ->
  Tensor device dtype '[4 * hiddenSize, hiddenSize] ->
  Tensor device dtype '[4 * hiddenSize] ->
  Tensor device dtype '[4 * hiddenSize] ->
  ( Tensor device dtype '[batchSize, hiddenSize],
    Tensor device dtype '[batchSize, hiddenSize]
  ) ->
  Tensor device dtype '[batchSize, inputSize] ->
  ( Tensor device dtype '[batchSize, hiddenSize],
    Tensor device dtype '[batchSize, hiddenSize]
  )
lstmCell :: forall (inputSize :: Nat) (hiddenSize :: Nat) (batchSize :: Nat)
       (dtype :: DType) (device :: (DeviceType, Nat)).
Tensor device dtype '[4 * hiddenSize, inputSize]
-> Tensor device dtype '[4 * hiddenSize, hiddenSize]
-> Tensor device dtype '[4 * hiddenSize]
-> Tensor device dtype '[4 * hiddenSize]
-> (Tensor device dtype '[batchSize, hiddenSize],
    Tensor device dtype '[batchSize, hiddenSize])
-> Tensor device dtype '[batchSize, inputSize]
-> (Tensor device dtype '[batchSize, hiddenSize],
    Tensor device dtype '[batchSize, hiddenSize])
lstmCell Tensor device dtype '[4 * hiddenSize, inputSize]
wi Tensor device dtype '[4 * hiddenSize, hiddenSize]
wh Tensor device dtype '[4 * hiddenSize]
bi Tensor device dtype '[4 * hiddenSize]
bh (Tensor device dtype '[batchSize, hiddenSize]
cc, Tensor device dtype '[batchSize, hiddenSize]
hc) Tensor device dtype '[batchSize, inputSize]
input =
  forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$
    forall a ca x1 cx1 x2 cx2 x3 cx3 x4 cx4 x5 cx5 y cy.
(Castable a ca, Castable x1 cx1, Castable x2 cx2, Castable x3 cx3,
 Castable x4 cx4, Castable x5 cx5, Castable y cy) =>
(ca -> cx1 -> cx2 -> cx3 -> cx4 -> cx5 -> IO cy)
-> a -> x1 -> x2 -> x3 -> x4 -> x5 -> IO y
ATen.cast6 ForeignPtr Tensor
-> ForeignPtr TensorList
-> ForeignPtr Tensor
-> ForeignPtr Tensor
-> ForeignPtr Tensor
-> ForeignPtr Tensor
-> IO (ForeignPtr (StdTuple '(Tensor, Tensor)))
ATen.Managed.lstm_cell_tltttt Tensor device dtype '[batchSize, inputSize]
input [Tensor device dtype '[batchSize, hiddenSize]]
hx Tensor device dtype '[4 * hiddenSize, inputSize]
wi Tensor device dtype '[4 * hiddenSize, hiddenSize]
wh Tensor device dtype '[4 * hiddenSize]
bi Tensor device dtype '[4 * hiddenSize]
bh
  where
    hx :: [Tensor device dtype '[batchSize, hiddenSize]]
hx = [Tensor device dtype '[batchSize, hiddenSize]
cc, Tensor device dtype '[batchSize, hiddenSize]
hc] :: [Tensor device dtype '[batchSize, hiddenSize]]

type GRUWIShape hiddenSize inputSize = '[3 * hiddenSize, inputSize]

type GRUWHShape hiddenSize inputSize = '[3 * hiddenSize, hiddenSize]

type GRUBIShape hiddenSize inputSize = '[3 * hiddenSize]

type GRUBHShape hiddenSize inputSize = '[3 * hiddenSize]

type family GRURImpl (inputSize :: Nat) (hiddenSize :: Nat) (numLayers :: Nat) (directionality :: RNNDirectionality) :: [[Nat]] where
  GRURImpl inputSize hiddenSize 1 'Unidirectional =
    '[ GRUWIShape hiddenSize inputSize,
       GRUWHShape hiddenSize inputSize,
       GRUBIShape hiddenSize inputSize,
       GRUBHShape hiddenSize inputSize
     ]
  GRURImpl inputSize hiddenSize numLayers 'Unidirectional =
    GRURImpl inputSize hiddenSize (numLayers - 1) 'Unidirectional
      ++ '[ GRUWIShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional),
            GRUWHShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional),
            GRUBIShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional),
            GRUBHShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional)
          ]
  GRURImpl inputSize hiddenSize 1 'Bidirectional =
    '[ GRUWIShape hiddenSize inputSize,
       GRUWHShape hiddenSize inputSize,
       GRUBIShape hiddenSize inputSize,
       GRUBHShape hiddenSize inputSize,
       GRUWIShape hiddenSize inputSize,
       GRUWHShape hiddenSize inputSize,
       GRUBIShape hiddenSize inputSize,
       GRUBHShape hiddenSize inputSize
     ]
  GRURImpl inputSize hiddenSize numLayers 'Bidirectional =
    GRURImpl inputSize hiddenSize (numLayers - 1) 'Bidirectional
      ++ '[ GRUWIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUWHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUBIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUBHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUWIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUWHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUBIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUBHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional)
          ]

type family GRUR' (shapes :: [[Nat]]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) :: [a] where
  GRUR' '[] dtype device = '[]
  GRUR' (shape ': shapes) dtype device = Tensor device dtype shape ': GRUR' shapes dtype device

type GRUR inputSize hiddenSize numLayers directionality dtype device = GRUR' (GRURImpl inputSize hiddenSize numLayers directionality) dtype device
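-- Analogously to 'LSTMR', the GRU parameter shapes stack the three gate matrices,
-- so every weight and bias has @3 * hiddenSize@ rows. For instance (illustrative,
-- not a verified doctest):
--
-- > GRUR 5 7 1 'Unidirectional dtype device
-- >   ~ '[ Tensor device dtype '[21, 5]
-- >      , Tensor device dtype '[21, 7]
-- >      , Tensor device dtype '[21]
-- >      , Tensor device dtype '[21]
-- >      ]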

-- | gru
-- The parameters of this ATen function are non-trivial to assemble by hand; see the
-- `Typed.NN.GRU` module for doctests.
gru ::
  forall
    shapeOrder
    directionality
    numLayers
    seqLen
    batchSize
    inputSize
    outputSize
    hiddenSize
    inputShape
    outputShape
    hcShape
    tensorParameters
    dtype
    device.
  ( KnownNat numLayers,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hcShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
    tensorParameters ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor]
  ) =>
  HList tensorParameters ->
  Double ->
  Bool ->
  Tensor device dtype hcShape ->
  Tensor device dtype inputShape ->
  ( Tensor device dtype outputShape,
    Tensor device dtype hcShape
  )
gru :: forall {k} (shapeOrder :: RNNShapeOrder)
       (directionality :: RNNDirectionality) (numLayers :: Nat)
       (seqLen :: Nat) (batchSize :: Nat) (inputSize :: Nat)
       (outputSize :: Nat) (hiddenSize :: Nat) (inputShape :: [Nat])
       (outputShape :: [Nat]) (hcShape :: [Nat]) (tensorParameters :: [k])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat numLayers, KnownRNNShapeOrder shapeOrder,
 KnownRNNDirectionality directionality,
 outputSize ~ (hiddenSize * NumberOfDirections directionality),
 inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
 outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
 hcShape
 ~ '[numLayers * NumberOfDirections directionality, batchSize,
     hiddenSize],
 tensorParameters
 ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
 Castable (HList tensorParameters) [ForeignPtr Tensor]) =>
HList tensorParameters
-> Double
-> Bool
-> Tensor device dtype hcShape
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hcShape)
gru HList tensorParameters
tensorParameters Double
dropoutProb Bool
dropoutOn Tensor device dtype hcShape
hc Tensor device dtype inputShape
input =
  forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$
    forall a ca x1 cx1 x2 cx2 x3 cx3 x4 cx4 x5 cx5 x6 cx6 x7 cx7 x8 cx8
       y cy.
(Castable a ca, Castable x1 cx1, Castable x2 cx2, Castable x3 cx3,
 Castable x4 cx4, Castable x5 cx5, Castable x6 cx6, Castable x7 cx7,
 Castable x8 cx8, Castable y cy) =>
(ca
 -> cx1 -> cx2 -> cx3 -> cx4 -> cx5 -> cx6 -> cx7 -> cx8 -> IO cy)
-> a -> x1 -> x2 -> x3 -> x4 -> x5 -> x6 -> x7 -> x8 -> IO y
ATen.cast9
      ForeignPtr Tensor
-> ForeignPtr Tensor
-> ForeignPtr TensorList
-> CBool
-> Int64
-> CDouble
-> CBool
-> CBool
-> CBool
-> IO (ForeignPtr (StdTuple '(Tensor, Tensor)))
ATen.Managed.gru_ttlbldbbb
      Tensor device dtype inputShape
input
      Tensor device dtype hcShape
hc
      HList tensorParameters
tensorParameters
      Bool
hasBiases
      Int64
numLayers
      Double
dropoutProb
      Bool
dropoutOn
      (forall (directionality :: RNNDirectionality).
KnownRNNDirectionality directionality =>
Bool
rnnBidirectional @directionality)
      (forall (shapeOrder :: RNNShapeOrder).
KnownRNNShapeOrder shapeOrder =>
Bool
rnnBatchFirst @shapeOrder)
  where
    hasBiases :: Bool
hasBiases = Bool
True
    numLayers :: I.Int64
    numLayers :: Int64
numLayers = forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall (n :: Nat). KnownNat n => Int
natValI @numLayers

-- | gruCell
--
-- >>> dtype &&& shape $ gruCell (ones :: CPUTensor 'D.Float '[9,2]) (ones :: CPUTensor 'D.Float '[9,3]) (ones :: CPUTensor 'D.Float '[9]) (ones :: CPUTensor 'D.Float '[9]) (ones :: CPUTensor 'D.Float '[2,3]) (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,3])
gruCell ::
  forall inputSize hiddenSize batchSize dtype device.
  Tensor device dtype '[3 * hiddenSize, inputSize] ->
  Tensor device dtype '[3 * hiddenSize, hiddenSize] ->
  Tensor device dtype '[3 * hiddenSize] ->
  Tensor device dtype '[3 * hiddenSize] ->
  Tensor device dtype '[batchSize, hiddenSize] ->
  Tensor device dtype '[batchSize, inputSize] ->
  Tensor device dtype '[batchSize, hiddenSize]
gruCell :: forall (inputSize :: Nat) (hiddenSize :: Nat) (batchSize :: Nat)
       (dtype :: DType) (device :: (DeviceType, Nat)).
Tensor device dtype '[3 * hiddenSize, inputSize]
-> Tensor device dtype '[3 * hiddenSize, hiddenSize]
-> Tensor device dtype '[3 * hiddenSize]
-> Tensor device dtype '[3 * hiddenSize]
-> Tensor device dtype '[batchSize, hiddenSize]
-> Tensor device dtype '[batchSize, inputSize]
-> Tensor device dtype '[batchSize, hiddenSize]
gruCell Tensor device dtype '[3 * hiddenSize, inputSize]
wi Tensor device dtype '[3 * hiddenSize, hiddenSize]
wh Tensor device dtype '[3 * hiddenSize]
bi Tensor device dtype '[3 * hiddenSize]
bh Tensor device dtype '[batchSize, hiddenSize]
hx Tensor device dtype '[batchSize, inputSize]
input =
  forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$
    forall a ca x1 cx1 x2 cx2 x3 cx3 x4 cx4 x5 cx5 y cy.
(Castable a ca, Castable x1 cx1, Castable x2 cx2, Castable x3 cx3,
 Castable x4 cx4, Castable x5 cx5, Castable y cy) =>
(ca -> cx1 -> cx2 -> cx3 -> cx4 -> cx5 -> IO cy)
-> a -> x1 -> x2 -> x3 -> x4 -> x5 -> IO y
ATen.cast6 ForeignPtr Tensor
-> ForeignPtr Tensor
-> ForeignPtr Tensor
-> ForeignPtr Tensor
-> ForeignPtr Tensor
-> ForeignPtr Tensor
-> IO (ForeignPtr Tensor)
ATen.Managed.gru_cell_tttttt Tensor device dtype '[batchSize, inputSize]
input Tensor device dtype '[batchSize, hiddenSize]
hx Tensor device dtype '[3 * hiddenSize, inputSize]
wi Tensor device dtype '[3 * hiddenSize, hiddenSize]
wh Tensor device dtype '[3 * hiddenSize]
bi Tensor device dtype '[3 * hiddenSize]
bh

-- rnn_tanh_cell :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- rnn_tanh_cell _input _hx _w_ih _w_hh _b_ih _b_hh = unsafePerformIO $ (ATen.cast6 ATen.Managed.rnn_tanh_cell_tttttt) _input _hx _w_ih _w_hh _b_ih _b_hh

-- rnn_relu_cell :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- rnn_relu_cell _input _hx _w_ih _w_hh _b_ih _b_hh = unsafePerformIO $ (ATen.cast6 ATen.Managed.rnn_relu_cell_tttttt) _input _hx _w_ih _w_hh _b_ih _b_hh

-- quantized_lstm :: Tensor device dtype shape -> [Tensor device dtype shape] -> [Tensor device dtype shape] -> Bool -> Int -> Double -> Bool -> Bool -> Bool -> DType -> (Tensor device dtype shape,Tensor device dtype shape,Tensor device dtype shape)
-- quantized_lstm _input _hx _params _has_biases _num_layers _dropout _train _bidirectional _batch_first _dtype = unsafePerformIO $ (ATen.ATen.cast10 ATen.Managed.quantized_lstm_tllbldbbbs) _input _hx _params _has_biases _num_layers _dropout _train _bidirectional _batch_first _dtype

-- quantized_lstm_cell :: Tensor device dtype shape -> [Tensor device dtype shape] -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Float -> Float -> Float -> (Tensor device dtype shape,Tensor device dtype shape)
-- quantized_lstm_cell _input _hx _w_ih _w_hh _b_ih _b_hh _packed_ih _packed_hh _col_offsets_ih _col_offsets_hh _scale_ih _scale_hh _zero_point_ih _zero_point_hh = unsafePerformIO $ (ATen.ATen.cast14 ATen.Managed.quantized_lstm_cell_tlttttttttssss) _input _hx _w_ih _w_hh _b_ih _b_hh _packed_ih _packed_hh _col_offsets_ih _col_offsets_hh _scale_ih _scale_hh _zero_point_ih _zero_point_hh

-- quantized_gru_cell :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Float -> Float -> Float -> Tensor device dtype shape
-- quantized_gru_cell _input _hx _w_ih _w_hh _b_ih _b_hh _packed_ih _packed_hh _col_offsets_ih _col_offsets_hh _scale_ih _scale_hh _zero_point_ih _zero_point_hh = unsafePerformIO $ (ATen.ATen.cast14 ATen.Managed.quantized_gru_cell_ttttttttttssss) _input _hx _w_ih _w_hh _b_ih _b_hh _packed_ih _packed_hh _col_offsets_ih _col_offsets_hh _scale_ih _scale_hh _zero_point_ih _zero_point_hh

-- quantized_rnn_relu_cell :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Float -> Float -> Float -> Tensor device dtype shape
-- quantized_rnn_relu_cell _input _hx _w_ih _w_hh _b_ih _b_hh _packed_ih _packed_hh _col_offsets_ih _col_offsets_hh _scale_ih _scale_hh _zero_point_ih _zero_point_hh = unsafePerformIO $ (ATen.ATen.cast14 ATen.Managed.quantized_rnn_relu_cell_ttttttttttssss) _input _hx _w_ih _w_hh _b_ih _b_hh _packed_ih _packed_hh _col_offsets_ih _col_offsets_hh _scale_ih _scale_hh _zero_point_ih _zero_point_hh

-- quantized_rnn_tanh_cell :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Float -> Float -> Float -> Tensor device dtype shape
-- quantized_rnn_tanh_cell _input _hx _w_ih _w_hh _b_ih _b_hh _packed_ih _packed_hh _col_offsets_ih _col_offsets_hh _scale_ih _scale_hh _zero_point_ih _zero_point_hh = unsafePerformIO $ (ATen.ATen.cast14 ATen.Managed.quantized_rnn_tanh_cell_ttttttttttssss) _input _hx _w_ih _w_hh _b_ih _b_hh _packed_ih _packed_hh _col_offsets_ih _col_offsets_hh _scale_ih _scale_hh _zero_point_ih _zero_point_hh

-- masked_scatter :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- masked_scatter _input _mask _source = unsafePerformIO $ (ATen.cast3 ATen.Managed.masked_scatter_ttt) _input _mask _source

-- index_add :: Tensor device dtype shape -> Int -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- index_add _input _dim _index _source = unsafePerformIO $ (ATen.cast4 ATen.Managed.index_add_tltt) _input _dim _index _source

-- scatter_add :: Tensor device dtype shape -> Int -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- scatter_add _input _dim _index _src = unsafePerformIO $ (ATen.cast4 ATen.Managed.scatter_add_tltt) _input _dim _index _src

-- addbmm :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Float -> Tensor device dtype shape
-- addbmm _input _batch1 _batch2 _beta _alpha = unsafePerformIO $ (ATen.cast5 ATen.Managed.addbmm_tttss) _input _batch1 _batch2 _beta _alpha

-- cross :: Tensor device dtype shape -> Tensor device dtype shape -> Int -> Tensor device dtype shape
-- cross _input _other _dim = unsafePerformIO $ (ATen.cast3 ATen.Managed.cross_ttl) _input _other _dim

type family MatrixOrMatrixBatch (shape :: [Nat]) :: [Nat] where
  MatrixOrMatrixBatch (n : m : '[]) = '[n, m]
  MatrixOrMatrixBatch (b : n : m : '[]) = '[b, n, m]
  MatrixOrMatrixBatch _ = TypeError (Text "The input must be a matrix or a batch of matrices.")
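-- 'MatrixOrMatrixBatch' accepts exactly rank-2 and rank-3 shapes and leaves them
-- unchanged; anything else is rejected with a type error. For example (illustrative):
--
-- > MatrixOrMatrixBatch '[3, 4]    ~ '[3, 4]
-- > MatrixOrMatrixBatch '[2, 3, 4] ~ '[2, 3, 4]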

-- | triu
-- TODO: triu is not implemented for D.Bool, or maybe numeric type is lifted?
--
-- >>> t = ones :: CPUTensor 'D.Float '[3, 4]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Float]]) $ triu 0 t
-- (Float,([3,4],[[1.0,1.0,1.0,1.0],[0.0,1.0,1.0,1.0],[0.0,0.0,1.0,1.0]]))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Float]]) $ triu 1 t
-- (Float,([3,4],[[0.0,1.0,1.0,1.0],[0.0,0.0,1.0,1.0],[0.0,0.0,0.0,1.0]]))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Float]]) $ triu (-1) t
-- (Float,([3,4],[[1.0,1.0,1.0,1.0],[1.0,1.0,1.0,1.0],[0.0,1.0,1.0,1.0]]))
triu ::
  forall shape dtype device.
  (shape ~ MatrixOrMatrixBatch shape) =>
  -- | diagonal
  Int ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
triu :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(shape ~ MatrixOrMatrixBatch shape) =>
Int -> Tensor device dtype shape -> Tensor device dtype shape
triu Int
diagonal Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca x1 cx1 y cy.
(Castable a ca, Castable x1 cx1, Castable y cy) =>
(ca -> cx1 -> IO cy) -> a -> x1 -> IO y
ATen.cast2 ForeignPtr Tensor -> Int64 -> IO (ForeignPtr Tensor)
ATen.Managed.triu_tl Tensor device dtype shape
input Int
diagonal

-- | tril
-- TODO: tril is not implemented for D.Bool, or maybe numeric type is lifted?
--
-- >>> t = ones :: CPUTensor 'D.Float '[3, 4]
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Float]]) $ tril 0 t
-- (Float,([3,4],[[1.0,0.0,0.0,0.0],[1.0,1.0,0.0,0.0],[1.0,1.0,1.0,0.0]]))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Float]]) $ tril 1 t
-- (Float,([3,4],[[1.0,1.0,0.0,0.0],[1.0,1.0,1.0,0.0],[1.0,1.0,1.0,1.0]]))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [[Float]]) $ tril (-1) t
-- (Float,([3,4],[[0.0,0.0,0.0,0.0],[1.0,0.0,0.0,0.0],[1.0,1.0,0.0,0.0]]))
tril ::
  forall shape dtype device.
  (shape ~ MatrixOrMatrixBatch shape) =>
  -- | diagonal
  Int ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
tril :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(shape ~ MatrixOrMatrixBatch shape) =>
Int -> Tensor device dtype shape -> Tensor device dtype shape
tril Int
diagonal Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca x1 cx1 y cy.
(Castable a ca, Castable x1 cx1, Castable y cy) =>
(ca -> cx1 -> IO cy) -> a -> x1 -> IO y
ATen.cast2 ForeignPtr Tensor -> Int64 -> IO (ForeignPtr Tensor)
ATen.Managed.tril_tl Tensor device dtype shape
input Int
diagonal

-- | trace
--
-- >>> dtype &&& shape $ trace (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
trace ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
trace :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> Tensor device dtype shape
trace Tensor device dtype shape
_input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ (forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.trace_t) Tensor device dtype shape
_input

-- take :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- take _input _index = unsafePerformIO $ (ATen.cast2 ATen.Managed.take_tt) _input _index

-- index_select :: Tensor device dtype shape -> Int -> Tensor device dtype shape -> Tensor device dtype shape
-- index_select _input _dim _index = unsafePerformIO $ (ATen.cast3 ATen.Managed.index_select_tlt) _input _dim _index

maskedSelect ::
  forall shape shape' shape'' dtype device.
  (shape'' ~ Broadcast shape shape') =>
  Tensor device 'D.Bool shape ->
  Tensor device dtype shape' ->
  UnknownShapeTensor device dtype
maskedSelect :: forall (shape :: [Nat]) (shape' :: [Nat]) (shape'' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ Broadcast shape shape') =>
Tensor device 'Bool shape
-> Tensor device dtype shape' -> UnknownShapeTensor device dtype
maskedSelect Tensor device 'Bool shape
mask Tensor device dtype shape'
input = forall (device :: (DeviceType, Nat)) (dtype :: DType)
       (shape :: [Nat]).
Tensor device dtype shape -> UnknownShapeTensor device dtype
UnknownShapeTensor forall a b. (a -> b) -> a -> b
$ forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca x1 cx1 y cy.
(Castable a ca, Castable x1 cx1, Castable y cy) =>
(ca -> cx1 -> IO cy) -> a -> x1 -> IO y
ATen.cast2 ForeignPtr Tensor -> ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.masked_select_tt Tensor device dtype shape'
input Tensor device 'Bool shape
mask

-- | nonzero
--
-- >>> dtype &&& shape $ nonzero (zeros :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
nonzero ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
nonzero :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> Tensor device dtype shape
nonzero Tensor device dtype shape
_input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ (forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.nonzero_t) Tensor device dtype shape
_input

-- nonzero_numpy :: Tensor device dtype shape -> [Tensor device dtype shape]
-- nonzero_numpy _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.nonzero_numpy_t) _input

-- | GatherDimImpl
--
-- >>> :kind! GatherDimImpl '[2, 1, 1] '[2, 4, 1] 1
-- GatherDimImpl '[2, 1, 1] '[2, 4, 1] 1 :: Maybe [Natural]
-- = 'Just '[2, 4, 1]
-- >>> :kind! GatherDimImpl '[2, 1, 1] '[2, 4, 2] 1
-- GatherDimImpl '[2, 1, 1] '[2, 4, 2] 1 :: Maybe [Natural]
-- = 'Nothing
-- >>> :kind! GatherDimImpl '[2, 1, 1] '[2, 0, 1] 1
-- GatherDimImpl '[2, 1, 1] '[2, 0, 1] 1 :: Maybe [Natural]
-- = 'Nothing
-- >>> :kind! GatherDimImpl '[2, 1, 1] '[2, 1] 1
-- GatherDimImpl '[2, 1, 1] '[2, 1] 1 :: Maybe [Natural]
-- = 'Nothing
-- >>> :kind! GatherDimImpl '[2, 1, 1] '[2, 1, 3] 2
-- GatherDimImpl '[2, 1, 1] '[2, 1, 3] 2 :: Maybe [Natural]
-- = 'Just '[2, 1, 3]
type family GatherDimImpl (shape :: [Nat]) (shape' :: [Nat]) (dim :: Nat) :: Maybe [Nat] where
  GatherDimImpl (x ': xs) (y ': xs) 0 = If (1 <=? y) (Just (y ': xs)) Nothing
  GatherDimImpl (x ': xs) (x ': ys) dim = AppendToMaybe x (GatherDimImpl xs ys (dim - 1))
  GatherDimImpl _ _ _ = Nothing

type family GatherDimCheck (shape :: [a]) (shape' :: [a]) (dim :: Nat) (result :: Maybe [a]) :: [a] where
  GatherDimCheck shape shape' dim Nothing =
    TypeError
      ( Text "Cannot gather the tensor at dimension "
          :<>: ShowType dim
          :<>: Text " using index of shape "
          :<>: ShowType shape'
      )
  GatherDimCheck _ _ _ (Just shape'') = shape''

-- | Calculate the output shape of a gather operation for a given index shape along a given axis
--
-- >>> :kind! GatherDim '[2, 1, 1] '[2, 1, 3] 2
-- GatherDim '[2, 1, 1] '[2, 1, 3] 2 :: [Natural]
-- = '[2, 1, 3]
type GatherDim shape shape' dim = GatherDimCheck shape shape' dim (GatherDimImpl shape shape' dim)

-- | Gather values along a specified dimension, selecting elements according to an index tensor.
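-- A usage sketch under assumed shapes (not a verified doctest): the index tensor has
-- the same rank as the input, and its entries index into the chosen dimension, e.g.
--
-- > xs  = ones  :: CPUTensor 'D.Float '[2, 3]
-- > ixs = zeros :: CPUTensor 'D.Int64 '[2, 3]
-- > ys  = gatherDim @1 ixs xs  -- ys :: CPUTensor 'D.Float '[2, 3]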
gatherDim ::
  forall dim shape shape' dtype device.
  (KnownNat dim, shape' ~ GatherDim shape shape' dim) =>
  -- | the indices of elements to gather
  Tensor device 'D.Int64 shape' ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
gatherDim :: forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ GatherDim shape shape' dim) =>
Tensor device 'Int64 shape'
-> Tensor device dtype shape -> Tensor device dtype shape'
gatherDim Tensor device 'Int64 shape'
index Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca x1 cx1 x2 cx2 x3 cx3 y cy.
(Castable a ca, Castable x1 cx1, Castable x2 cx2, Castable x3 cx3,
 Castable y cy) =>
(ca -> cx1 -> cx2 -> cx3 -> IO cy) -> a -> x1 -> x2 -> x3 -> IO y
ATen.cast4 ForeignPtr Tensor
-> Int64 -> ForeignPtr Tensor -> CBool -> IO (ForeignPtr Tensor)
ATen.Managed.gather_tltb Tensor device dtype shape
input (forall (n :: Nat). KnownNat n => Int
natValI @dim) Tensor device 'Int64 shape'
index Bool
False

-- addcmul :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Tensor device dtype shape
-- addcmul _input _tensor1 _tensor2 _value = unsafePerformIO $ (ATen.cast4 ATen.Managed.addcmul_ttts) _input _tensor1 _tensor2 _value

-- addcdiv :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Float -> Tensor device dtype shape
-- addcdiv _input _tensor1 _tensor2 _value = unsafePerformIO $ (ATen.cast4 ATen.Managed.addcdiv_ttts) _input _tensor1 _tensor2 _value

-- lstsq :: Tensor device dtype shape -> Tensor device dtype shape -> (Tensor device dtype shape,Tensor device dtype shape)
-- lstsq _input _A = unsafePerformIO $ (ATen.cast2 ATen.Managed.lstsq_tt) _input _A

-- triangular_solve :: Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Bool -> Bool -> (Tensor device dtype shape,Tensor device dtype shape)
-- triangular_solve _input _A _upper _transpose _unitriangular = unsafePerformIO $ (ATen.cast5 ATen.Managed.triangular_solve_ttbbb) _input _A _upper _transpose _unitriangular

-- qr :: Tensor device dtype shape -> Bool -> (Tensor device dtype shape,Tensor device dtype shape)
-- qr _input _some = unsafePerformIO $ (ATen.cast2 ATen.Managed.qr_tb) _input _some

-- ormqr :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Bool -> Bool -> Tensor device dtype shape
-- ormqr _input _input2 _input3 _left _transpose = unsafePerformIO $ (ATen.cast5 ATen.Managed.ormqr_tttbb) _input _input2 _input3 _left _transpose

-- lu_solve :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- lu_solve _input _LU_data _LU_pivots = unsafePerformIO $ (ATen.cast3 ATen.Managed.lu_solve_ttt) _input _LU_data _LU_pivots

-- | lgamma function
--
-- >>> dtype &&& shape $ lgamma (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
lgamma ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
lgamma :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
StandardFloatingPointDTypeValidation device dtype =>
Tensor device dtype shape -> Tensor device dtype shape
lgamma Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.lgamma_t Tensor device dtype shape
input

-- | digamma function
--
-- >>> dtype &&& shape $ digamma (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
digamma ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
digamma :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
StandardFloatingPointDTypeValidation device dtype =>
Tensor device dtype shape -> Tensor device dtype shape
digamma Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.digamma_t Tensor device dtype shape
input

-- | polygamma function
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
polygamma ::
  forall shape dtype device.
  -- | order
  Int ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
polygamma :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Int -> Tensor device dtype shape -> Tensor device dtype shape
polygamma Int
n Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca x1 cx1 y cy.
(Castable a ca, Castable x1 cx1, Castable y cy) =>
(ca -> cx1 -> IO cy) -> a -> x1 -> IO y
ATen.cast2 Int64 -> ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.polygamma_lt Int
n Tensor device dtype shape
input

-- | inverse of the error function
--
-- >>> dtype &&& shape $ erfinv (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
erfinv ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
erfinv :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
StandardFloatingPointDTypeValidation device dtype =>
Tensor device dtype shape -> Tensor device dtype shape
erfinv Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.erfinv_t Tensor device dtype shape
input

-- dist :: Tensor device dtype shape -> Tensor device dtype shape -> Float -> Tensor device dtype shape
-- dist _input _other _p = unsafePerformIO $ (ATen.cast3 ATen.Managed.dist_tts) _input _other _p

-- atan2 :: Tensor device dtype shape -> Tensor device dtype shape -> Tensor device dtype shape
-- atan2 _input _other = unsafePerformIO $ (ATen.cast2 ATen.Managed.atan2_tt) _input _other

-- histc :: Tensor device dtype shape -> Int -> Float -> Float -> Tensor device dtype shape
-- histc _input _bins _min _max = unsafePerformIO $ (ATen.cast4 ATen.Managed.histc_tlss) _input _bins _min _max

-- | minAll
--
-- >>> dtype &&& shape $ minAll (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
minAll ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
minAll :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> Tensor device dtype '[]
minAll Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.min_t Tensor device dtype shape
input

type family DropValue (shape :: [Nat]) (i :: Nat) :: [Nat] where
  DropValue '[] _ = TypeError (Text "Cannot find an element in the list.")
  DropValue (x : xs) 0 = xs
  DropValue (x : xs) i = x ': DropValue xs (i -1)
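-- 'DropValue' removes the dimension at the given zero-based index, e.g.
-- (illustrative, not a verified doctest):
--
-- > DropValue '[3, 4, 5] 1 ~ '[3, 5]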

type family DropNamedValue (shape :: Shape) (i :: Size) :: Shape where
  DropNamedValue '[] _ = TypeError (Text "Cannot find an element in the list.")
  DropNamedValue (x : xs) x = xs
  DropNamedValue (x : xs) y = x ': DropNamedValue xs y

-- | minDim
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ fst $ minDim @0 t
-- (Float,[4,5])
-- >>> dtype &&& shape $ fst $ minDim @1 t
-- (Float,[3,5])
-- >>> dtype &&& shape $ fst $ minDim @2 t
-- (Float,[3,4])
minDim ::
  forall d shape dtype device.
  (KnownNat d) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  ( Tensor device dtype (DropValue shape d),
    Tensor device 'D.Int64 (DropValue shape d)
  )
minDim :: forall (d :: Nat) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
KnownNat d =>
Tensor device dtype shape
-> (Tensor device dtype (DropValue shape d),
    Tensor device 'Int64 (DropValue shape d))
minDim Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca x1 cx1 y cy.
(Castable a ca, Castable x1 cx1, Castable y cy) =>
(ca -> cx1 -> IO cy) -> a -> x1 -> IO y
ATen.cast2 ForeignPtr Tensor
-> Int64 -> IO (ForeignPtr (StdTuple '(Tensor, Tensor)))
ATen.Managed.min_tl Tensor device dtype shape
input (forall (n :: Nat). KnownNat n => Int
natValI @d)

-- | maxAll
--
-- >>> dtype &&& shape $ maxAll (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
maxAll ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
maxAll :: forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> Tensor device dtype '[]
maxAll Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca y cy.
(Castable a ca, Castable y cy) =>
(ca -> IO cy) -> a -> IO y
ATen.cast1 ForeignPtr Tensor -> IO (ForeignPtr Tensor)
ATen.Managed.max_t Tensor device dtype shape
input

-- | maxDim
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ fst $ maxDim @0 t
-- (Float,[4,5])
-- >>> dtype &&& shape $ fst $ maxDim @1 t
-- (Float,[3,5])
-- >>> dtype &&& shape $ fst $ maxDim @2 t
-- (Float,[3,4])
maxDim ::
  forall d shape dtype device.
  (KnownNat d) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  ( Tensor device dtype (DropValue shape d),
    Tensor device 'D.Int64 (DropValue shape d)
  )
maxDim :: forall (d :: Nat) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
KnownNat d =>
Tensor device dtype shape
-> (Tensor device dtype (DropValue shape d),
    Tensor device 'Int64 (DropValue shape d))
maxDim Tensor device dtype shape
input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ forall a ca x1 cx1 y cy.
(Castable a ca, Castable x1 cx1, Castable y cy) =>
(ca -> cx1 -> IO cy) -> a -> x1 -> IO y
ATen.cast2 ForeignPtr Tensor
-> Int64 -> IO (ForeignPtr (StdTuple '(Tensor, Tensor)))
ATen.Managed.max_tl Tensor device dtype shape
input (forall (n :: Nat). KnownNat n => Int
natValI @d)

type family HasDim (dim :: Nat) (shape :: [Nat]) :: Constraint where
  HasDim _ '[] = TypeError (Text "The specified dimension is out of range for the shape of the tensor.")
  HasDim 0 (_ ': _) = ()
  HasDim n (_ ': xs) = HasDim (n -1) xs
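-- 'HasDim' only checks that the dimension index is in range, e.g. (illustrative)
-- @HasDim 1 '[3, 4, 5]@ is satisfied, while @HasDim 3 '[3, 4, 5]@ is a type error.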

-- | sortDim
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ fst $ sortDim @0 True t
-- (Float,[3,4,5])
-- >>> dtype &&& shape $ snd $ sortDim @0 True t
-- (Int64,[3,4,5])
sortDim ::
  forall dim shape dtype device.
  ( KnownNat dim,
    HasDim dim shape
  ) =>
  Bool ->
  Tensor device dtype shape ->
  ( Tensor device dtype shape,
    Tensor device D.Int64 shape
  )
sortDim :: forall (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(KnownNat dim, HasDim dim shape) =>
Bool
-> Tensor device dtype shape
-> (Tensor device dtype shape, Tensor device 'Int64 shape)
sortDim Bool
_descending Tensor device dtype shape
_input =
  let (Tensor
a, Tensor
b) = Tensor -> (Tensor, Tensor)
func (forall t. Unnamed t => t -> Tensor
toDynamic Tensor device dtype shape
_input)
   in (forall (device :: (DeviceType, Nat)) (dtype :: DType)
       (shape :: [Nat]).
Tensor -> Tensor device dtype shape
UnsafeMkTensor Tensor
a, forall (device :: (DeviceType, Nat)) (dtype :: DType)
       (shape :: [Nat]).
Tensor -> Tensor device dtype shape
UnsafeMkTensor Tensor
b)
  where
    func :: D.Tensor -> (D.Tensor, D.Tensor)
    func :: Tensor -> (Tensor, Tensor)
func Tensor
_input = forall a. IO a -> a
unsafePerformIO forall a b. (a -> b) -> a -> b
$ (forall a ca x1 cx1 x2 cx2 y cy.
(Castable a ca, Castable x1 cx1, Castable x2 cx2, Castable y cy) =>
(ca -> cx1 -> cx2 -> IO cy) -> a -> x1 -> x2 -> IO y
ATen.cast3 ForeignPtr Tensor
-> Int64 -> CBool -> IO (ForeignPtr (StdTuple '(Tensor, Tensor)))
ATen.Managed.sort_tlb) Tensor
_input Int
_dim Bool
_descending
    _dim :: Int
_dim = forall (n :: Nat). KnownNat n => Int
natValI @dim

-- | sortNamedDim
--
-- >>> import Torch.Typed.Factories
-- >>> import Data.Default.Class
-- >>> t = def :: NamedTensor '( D.CPU, 0) 'D.Float '[Vector 3, Vector 4, Vector 5]
-- >>> dtype &&& shape $ fst $ sortNamedDim @(Vector 3) True t
-- (Float,[3,4,5])
-- >>> dtype &&& shape $ snd $ sortNamedDim @(Vector 3) True t
-- (Int64,[3,4,5])
sortNamedDim ::
  forall dim shape dtype device.
  ( KnownNat (FindDim dim shape)
  ) =>
  Bool ->
  NamedTensor device dtype shape ->
  ( NamedTensor device dtype shape,
    NamedTensor device D.Int64 shape
  )
sortNamedDim _descending _input =
  let (a, b) = func (toDynamic _input)
   in (FromTensor $ UnsafeMkTensor a, FromTensor $ UnsafeMkTensor b)
  where
    func :: D.Tensor -> (D.Tensor, D.Tensor)
    func _input = unsafePerformIO $ (ATen.cast3 ATen.Managed.sort_tlb) _input _dim _descending
    _dim :: Int
    _dim = natValI @(FindDim dim shape)

-- | argSortDim
--
-- >>> t = ones :: CPUTensor 'D.Float '[3,4,5]
-- >>> dtype &&& shape $ argSortDim @0 True t
-- (Int64,[3,4,5])
argSortDim ::
  forall dim shape dtype device.
  ( KnownNat dim,
    HasDim dim shape
  ) =>
  Bool ->
  Tensor device dtype shape ->
  Tensor device D.Int64 shape
argSortDim _descending _input = unsafePerformIO $ (ATen.cast3 ATen.Managed.argsort_tlb) _input _dim _descending
  where
    _dim :: Int
    _dim = natValI @dim

argSortNamedDim ::
  forall dim shape dtype device.
  ( KnownNat (FindDim dim shape)
  ) =>
  Bool ->
  NamedTensor device dtype shape ->
  NamedTensor device D.Int64 shape
argSortNamedDim _descending _input = unsafePerformIO $ (ATen.cast3 ATen.Managed.argsort_tlb) _input _dim _descending
  where
    _dim :: Int
    _dim = natValI @(FindDim dim shape)
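
-- A usage sketch added for illustration (kept as a comment because, like the
-- 'sortNamedDim' doctest above, it needs Data.Default.Class in scope):
--
-- > t = def :: NamedTensor '( D.CPU, 0) 'D.Float '[Vector 3, Vector 4, Vector 5]
-- > argSortNamedDim @(Vector 3) True t
-- >   :: NamedTensor '( D.CPU, 0) 'D.Int64 '[Vector 3, Vector 4, Vector 5]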

type family TopKCheck (k :: Nat) (shape :: [Nat]) (dim :: Nat) (satd :: Maybe Nat) (result :: Maybe a) :: a where
  TopKCheck _ shape dim _ Nothing = DimOutOfBound shape dim
  TopKCheck _ shape dim Nothing _ = DimOutOfBound shape dim
  TopKCheck k shape dim (Just v) (Just result) = If (k <=? v) result (TypeError (Text "k must be less than or equal to the number of elements in the requested dimension."))

type TopK k shape dim = TopKCheck k shape dim (ExtractDim dim shape) (ReplaceDim dim shape k)

type family TopKDeviceAndDTypeCheck dtype (device :: (D.DeviceType, Nat)) :: Constraint where
  TopKDeviceAndDTypeCheck D.Bool _ = (TypeError (Text "topk is not defined for Bool tensors."))
  TopKDeviceAndDTypeCheck D.Half '(D.CPU, _) = (TypeError (Text "topk is not defined for Half types on CPU."))
  TopKDeviceAndDTypeCheck _ _ = ()
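
-- A worked reduction added for illustration: @TopK 3 '[2, 5] 1@ extracts the
-- requested dimension's extent ('ExtractDim' gives 'Just 5'), builds the
-- candidate result shape ('ReplaceDim' gives 'Just '[2, 3]), and 'TopKCheck'
-- then verifies @3 <=? 5@ before returning '[2, 3]; an oversized @k@ or an
-- out-of-bounds dimension hits one of the 'TypeError' branches instead.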

-- | Returns the k largest (if largest is `True`) elements of the given input tensor along a given dimension.
--
-- >>> topk @3 @1 True True (ones :: CPUTensor 'D.Float '[2,3])
-- (Tensor Float [2,3] [[ 1.0000   ,  1.0000   ,  1.0000   ],
--                     [ 1.0000   ,  1.0000   ,  1.0000   ]],Tensor Int64 [2,3] [[ 0,  1,  2],
--                     [ 0,  1,  2]])
-- >>> topk @0 @1 True True (ones :: CPUTensor 'D.Float '[2,3])
-- (Tensor Float [2,0] [[],
--                     []],Tensor Int64 [2,0] [[],
--                     []])
topk ::
  forall k dim shape' shape dtype device.
  ( KnownNat k,
    KnownNat dim,
    All KnownNat shape,
    TopKDeviceAndDTypeCheck dtype device,
    shape' ~ TopK k shape dim
  ) =>
  -- | if we're returning the top k largest (or, if False, the top k smallest)
  Bool ->
  -- | if the resulting k elements are themselves sorted
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  (Tensor device dtype shape', Tensor device 'D.Int64 shape')
topk _largest _sorted _input = unsafePerformIO $ (ATen.cast5 ATen.Managed.topk_tllbb) _input _k _dim _largest _sorted
  where
    _k :: Int
    _k = natValI @k
    _dim :: Int
    _dim = natValI @dim
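
-- A usage sketch added for illustration (not part of the upstream API): with
-- @k = 2@ along dimension 1 of a '[2, 5] tensor, 'TopK' replaces that
-- dimension's extent with @k@, so both outputs have shape '[2, 2].
_topkExample :: (CPUTensor 'D.Float '[2, 2], CPUTensor 'D.Int64 '[2, 2])
_topkExample = topk @2 @1 True True (ones :: CPUTensor 'D.Float '[2, 5])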

-- renorm :: Tensor device dtype shape -> Float -> Int -> Float -> Tensor device dtype shape
-- renorm _input _p _dim _maxnorm = unsafePerformIO $ (ATen.cast4 ATen.Managed.renorm_tsls) _input _p _dim _maxnorm

-- equal :: Tensor device dtype shape -> Tensor device dtype shape -> Bool
-- equal _input _other = unsafePerformIO $ (ATen.cast2 ATen.Managed.equal_tt) _input _other

-- | alias
--
-- >>> dtype &&& shape $ alias (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
alias ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
alias _input = unsafePerformIO $ (ATen.cast1 ATen.Managed.alias_t) _input

-- | L1 loss
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ l1Loss @ReduceNone (ones :: CPUTensor 'D.Float '[2,2]) (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
-- >>> dtype &&& shape $ l1Loss @ReduceSum (ones :: CPUTensor 'D.Float '[2,2]) (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
l1Loss ::
  forall reduction shape dtype device.
  (KnownReduction reduction) =>
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | loss
  Tensor device dtype (ConditionalReduction shape reduction)
l1Loss prediction target =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.l1_loss_ttl prediction target (reductionVal @reduction)

-- multi_margin_loss :: Tensor device dtype shape -> Tensor device dtype shape -> Float -> Float -> Tensor device dtype shape -> Int -> Tensor device dtype shape
-- multi_margin_loss _input _target _p _margin _weight _reduction = unsafePerformIO $ (ATen.cast6 ATen.Managed.multi_margin_loss_ttsstl) _input _target _p _margin _weight _reduction

-- multilabel_margin_loss :: Tensor device dtype shape -> Tensor device dtype shape -> Int -> Tensor device dtype shape
-- multilabel_margin_loss _input _target _reduction = unsafePerformIO $ (ATen.cast3 ATen.Managed.multilabel_margin_loss_ttl) _input _target _reduction

-- | negative log likelihood loss
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- See https://pytorch.org/docs/stable/nn.functional.html?highlight=nll_loss#torch.nn.functional.nll_loss.
--
-- >>> input <- randn @'[3, 5] @'D.Float @'( 'D.CPU, 0)
-- >>> target = fromJust [1, 0, 4] :: CPUTensor 'D.Int64 '[3]
-- >>> weight = ones @'[5] @'D.Float @'( 'D.CPU, 0)
-- >>> dtype &&& shape $ nllLoss @ReduceNone @3 @5 @'[] weight (-100) (logSoftmax @1 input) target
-- (Float,[3])
-- >>> dtype &&& shape $ nllLoss @ReduceMean @3 @5 @'[] weight (-100) (logSoftmax @1 input) target
-- (Float,[])
-- >>> input <- randn @'[3, 5, 2] @'D.Float @'( 'D.CPU, 0)
-- >>> target = fromJust [[1, 1], [0, 1], [4, 0]] :: CPUTensor 'D.Int64 '[3, 2]
-- >>> weight = ones @'[5] @'D.Float @'( 'D.CPU, 0)
-- >>> dtype &&& shape $ nllLoss @ReduceNone @3 @5 @'[2] weight (-100) (logSoftmax @1 input) target
-- (Float,[3,2])
-- >>> dtype &&& shape $ nllLoss @ReduceMean @3 @5 @'[2] weight (-100) (logSoftmax @1 input) target
-- (Float,[])
-- >>> input <- randn @'[3, 5, 1, 2] @'D.Float @'( 'D.CPU, 0)
-- >>> target = fromJust [[[1, 1]], [[0, 1]], [[4, 0]]] :: CPUTensor 'D.Int64 '[3, 1, 2]
-- >>> weight = ones @'[5] @'D.Float @'( 'D.CPU, 0)
-- >>> dtype &&& shape $ nllLoss @ReduceNone @3 @5 @'[1, 2] weight (-100) (logSoftmax @1 input) target
-- (Float,[3,1,2])
-- >>> dtype &&& shape $ nllLoss @ReduceMean @3 @5 @'[1, 2] weight (-100) (logSoftmax @1 input) target
-- (Float,[])
-- >>> input <- randn @'[3, 5, 2, 1, 2] @'D.Float @'( 'D.CPU, 0)
-- >>> target = fromJust [[[[1, 1]], [[0, 2]]], [[[0, 1]], [[1, 0]]], [[[4, 0]], [[1, 2]]]] :: CPUTensor 'D.Int64 '[3, 2, 1, 2]
-- >>> weight = ones @'[5] @'D.Float @'( 'D.CPU, 0)
-- >>> dtype &&& shape $ nllLoss @ReduceNone @3 @5 @'[2, 1, 2] weight (-100) (logSoftmax @1 input) target
-- (Float,[3,2,1,2])
-- >>> dtype &&& shape $ nllLoss @ReduceMean @3 @5 @'[2, 1, 2] weight (-100) (logSoftmax @1 input) target
-- (Float,[])
nllLoss ::
  forall reduction n c ds dtype device.
  (KnownReduction reduction, KnownNat n, KnownNat c, KnownShape ds) =>
  -- | weight
  Tensor device dtype '[c] ->
  -- | index to ignore
  Int ->
  -- | prediction
  Tensor device dtype (n ': c ': ds) ->
  -- | target
  Tensor device 'D.Int64 (n ': ds) ->
  -- | loss
  Tensor device dtype (ConditionalReduction (n ': ds) reduction)
nllLoss weight ignoreIndex prediction target = case shapeVal @ds of
  [] ->
    unsafePerformIO $
      ATen.cast5
        ATen.Managed.nll_loss_tttll
        prediction
        target
        weight
        (reductionVal @reduction)
        ignoreIndex
  [_h, _w] ->
    unsafePerformIO $
      ATen.cast5
        ATen.Managed.nll_loss2d_tttll
        prediction
        target
        weight
        (reductionVal @reduction)
        ignoreIndex
  h : t -> case reductionVal @reduction of
    0 -> UnsafeMkTensor . D.reshape (natValI @n : h : t) $ out
    _ -> UnsafeMkTensor out
    where
      t' :: [Int]
      t' = [1, foldl (*) h t]
      input' :: D.Tensor
      input' = D.reshape (natValI @n : natValI @c : t') (toDynamic prediction)
      target' :: D.Tensor
      target' = D.reshape (natValI @n : t') (toDynamic target)
      out :: D.Tensor
      out =
        unsafePerformIO $
          ATen.cast5
            ATen.Managed.nll_loss2d_tttll
            input'
            target'
            weight
            (reductionVal @reduction)
            ignoreIndex
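
-- A note added for illustration: the ATen calls used here only cover targets
-- of rank 1 (nll_loss) and rank 3 (nll_loss2d), so for higher-rank targets the
-- branch above flattens all trailing dimensions into a fake spatial pair
-- [1, h * product t], calls nll_loss2d on the reshaped [n, c, 1, h * product t]
-- prediction, and, for 'ReduceNone', reshapes the per-element losses back to
-- the original [n, h, t...] layout.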

-- | smooth L1 loss
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ smoothL1Loss @ReduceNone (ones :: CPUTensor 'D.Float '[2,2]) (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
-- >>> dtype &&& shape $ smoothL1Loss @ReduceSum (ones :: CPUTensor 'D.Float '[2,2]) (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
smoothL1Loss ::
  forall reduction shape dtype device.
  (KnownReduction reduction) =>
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | loss
  Tensor device dtype (ConditionalReduction shape reduction)
smoothL1Loss prediction target =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.smooth_l1_loss_ttl prediction target (reductionVal @reduction)

-- | soft margin loss
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ softMarginLoss @ReduceNone (ones :: CPUTensor 'D.Float '[2,2]) (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[2,2])
-- >>> dtype &&& shape $ softMarginLoss @ReduceSum (ones :: CPUTensor 'D.Float '[2,2]) (ones :: CPUTensor 'D.Float '[2,2])
-- (Float,[])
softMarginLoss ::
  forall reduction shape dtype device.
  (KnownReduction reduction) =>
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | loss
  Tensor device dtype (ConditionalReduction shape reduction)
softMarginLoss prediction target =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.soft_margin_loss_ttl
      prediction
      target
      (reductionVal @reduction)

-- | elu
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ elu 0.1 0.1 0.3 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
elu ::
  forall shape dtype a device.
  (D.Scalar a, StandardFloatingPointDTypeValidation device dtype) =>
  -- | alpha
  a ->
  -- | scale
  a ->
  -- | input scale
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
elu alpha scale inputScale input =
  unsafePerformIO $ ATen.cast4 ATen.Managed.elu_tsss input alpha scale inputScale
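
-- A note added for illustration (stated as an assumption based on the usual
-- ATen elu semantics, not something this wrapper enforces): elu_tsss computes
-- scale * x for positive inputs and scale * alpha * (exp(x * inputScale) - 1)
-- otherwise, which is also how SELU is obtained from this primitive.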

-- | glu
-- -- >>> dtype &&& shape $ glu (ones :: CPUTensor 'D.Float '[3,2]) 1
-- -- (Float,[3,1])
-- -- >>> dtype &&& shape $ glu (ones :: CPUTensor 'D.Float '[3,2]) 3
-- -- (Float,[3,2])
-- glu :: Tensor device dtype shape -> Int -> Tensor device dtype shape
-- glu _input _dim = unsafePerformIO $ (ATen.cast2 ATen.Managed.glu_tl) _input _dim

-- | hard tanh
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ hardTanh 0 1 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
hardTanh ::
  forall shape dtype device.
  -- | minimum value
  Float ->
  -- | maximum value
  Float ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
hardTanh min_val max_val input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.hardtanh_tss input min_val max_val

-- | leaky relu
--
-- >>> dtype &&& shape $ leakyRelu 0.01 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
leakyRelu ::
  forall a shape dtype device.
  (D.Scalar a, StandardFloatingPointDTypeValidation device dtype) =>
  -- | negative slope
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
leakyRelu negativeSlope input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.leaky_relu_ts input negativeSlope

-- | logarithm of the sigmoid
--
-- >>> dtype &&& shape $ logSigmoid (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
logSigmoid ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
logSigmoid input = unsafePerformIO $ ATen.cast1 ATen.Managed.log_sigmoid_t input

-- | softplus
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- See https://pytorch.org/docs/stable/nn.functional.html?highlight=softplus#torch.nn.functional.softplus.
--
-- >>> dtype &&& shape &&& (\t -> D.asValue (toDynamic t) :: [[Float]]) $ softplus 1 20 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,([3,2],[[1.3132616,1.3132616],[1.3132616,1.3132616],[1.3132616,1.3132616]]))
softplus ::
  forall a shape dtype device.
  D.Scalar a =>
  a ->
  a ->
  Tensor device dtype shape ->
  Tensor device dtype shape
softplus beta threshold input = unsafePerformIO $ ATen.cast3 ATen.Managed.softplus_tss input beta threshold
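
-- A note added for illustration (mirroring the PyTorch documentation linked
-- above): softplus computes (1 / beta) * log(1 + exp(beta * x)) and falls back
-- to the identity where beta * x exceeds the given threshold, keeping the op
-- numerically stable; the doctest value log(1 + exp 1) ~ 1.3133 matches this
-- with beta = 1.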

-- | soft shrink
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> dtype &&& shape $ softShrink 0.2 (ones :: CPUTensor 'D.Float '[3,2])
-- (Float,[3,2])
softShrink ::
  forall shape dtype device.
  -- | lambda
  Float ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
softShrink lambda input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.softshrink_ts input lambda

-- | adaptive averaged 2-D pooling
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = adaptiveAvgPool2d @'(8,16) (ones :: CPUTensor 'D.Float '[1,3,16,32])
-- >>> shape t
-- [1,3,8,16]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 8, 16]
adaptiveAvgPool2d ::
  forall outputSize channelSize inputSize0 inputSize1 batchSize dtype device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         batchSize,
         Torch.Typed.Auxiliary.Fst outputSize,
         Torch.Typed.Auxiliary.Snd outputSize
       ]
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, Torch.Typed.Auxiliary.Fst outputSize, Torch.Typed.Auxiliary.Snd outputSize]
adaptiveAvgPool2d input =
  unsafePerformIO $
    ATen.cast2
      ATen.Managed.adaptive_avg_pool2d_tl
      input
      ([natValI @(Torch.Typed.Auxiliary.Fst outputSize), natValI @(Torch.Typed.Auxiliary.Snd outputSize)] :: [Int])

-- | MKLDNN adaptive averaged 2-D pooling
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: broken?
-- TODO: only defined for MKLDNN device?
-- TODO: test for availability of MKLDNN device?
-- TODO: merge with adaptiveAvgPool2d and dispatch based on (availability of MKLDNN) device in the function body?
--
-- -- >>> t = mkldnnAdaptiveAvgPool2d @'(8,16) (toMKLDNN (ones :: CPUTensor 'D.Float '[1,3,16,32]))
-- -- >>> shape t
-- -- [1,3,8,16]
-- -- >>> :t t
-- -- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 8, 16]
mkldnnAdaptiveAvgPool2d ::
  forall outputSize channelSize inputSize0 inputSize1 batchSize dtype device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         batchSize,
         Torch.Typed.Auxiliary.Fst outputSize,
         Torch.Typed.Auxiliary.Snd outputSize
       ]
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, Torch.Typed.Auxiliary.Fst outputSize, Torch.Typed.Auxiliary.Snd outputSize]
mkldnnAdaptiveAvgPool2d input =
  unsafePerformIO $
    ATen.cast2
      ATen.Managed.adaptive_avg_pool2d_tl
      input
      ([natValI @(Torch.Typed.Auxiliary.Fst outputSize), natValI @(Torch.Typed.Auxiliary.Snd outputSize)] :: [Int])

-- | adaptive averaged 3-D pooling
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = adaptiveAvgPool3d @'(8,16,2) (ones :: CPUTensor 'D.Float '[1,3,16,32,4])
-- >>> shape t
-- [1,3,8,16,2]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 8, 16, 2]
adaptiveAvgPool3d ::
  forall
    outputSize
    channelSize
    inputSize0
    inputSize1
    inputSize2
    batchSize
    dtype
    device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize,
         Fst3 outputSize,
         Snd3 outputSize,
         Trd3 outputSize
       ]
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1, inputSize2] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, Fst3 outputSize, Snd3 outputSize, Trd3 outputSize]
adaptiveAvgPool3d input =
  unsafePerformIO $
    ATen.cast2
      ATen.Managed.adaptive_avg_pool3d_tl
      input
      ( [ natValI @(Fst3 outputSize),
          natValI @(Snd3 outputSize),
          natValI @(Trd3 outputSize)
        ] ::
          [Int]
      )

-- | adaptive 2-D max-pool
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> (t, t') = adaptiveMaxPool2d @'(8,16) (ones :: CPUTensor 'D.Float '[1,3,16,32])
-- >>> shape t
-- [1,3,8,16]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 8, 16]
adaptiveMaxPool2d ::
  forall outputSize channelSize inputSize0 inputSize1 batchSize dtype device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         batchSize,
         Torch.Typed.Auxiliary.Fst outputSize,
         Torch.Typed.Auxiliary.Snd outputSize
       ]
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  ( Tensor device dtype '[batchSize, channelSize, Torch.Typed.Auxiliary.Fst outputSize, Torch.Typed.Auxiliary.Snd outputSize],
    Tensor device 'D.Int64 '[batchSize, channelSize, Torch.Typed.Auxiliary.Fst outputSize, Torch.Typed.Auxiliary.Snd outputSize]
  )
adaptiveMaxPool2d input =
  unsafePerformIO $
    ATen.cast2
      ATen.Managed.adaptive_max_pool2d_tl
      input
      ([natValI @(Torch.Typed.Auxiliary.Fst outputSize), natValI @(Torch.Typed.Auxiliary.Snd outputSize)] :: [Int])

-- | adaptive 3-D max-pool
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> (t, t') = adaptiveMaxPool3d @'(8,16,2) (ones :: CPUTensor 'D.Float '[1,3,16,32,4])
-- >>> shape t
-- [1,3,8,16,2]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 8, 16, 2]
adaptiveMaxPool3d ::
  forall
    outputSize
    channelSize
    inputSize0
    inputSize1
    inputSize2
    batchSize
    dtype
    device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize,
         Fst3 outputSize,
         Snd3 outputSize,
         Trd3 outputSize
       ]
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1, inputSize2] ->
  -- | output
  ( Tensor device dtype '[batchSize, channelSize, Fst3 outputSize, Snd3 outputSize, Trd3 outputSize],
    Tensor device 'D.Int64 '[batchSize, channelSize, Fst3 outputSize, Snd3 outputSize, Trd3 outputSize]
  )
adaptiveMaxPool3d input =
  unsafePerformIO $
    (ATen.cast2 ATen.Managed.adaptive_max_pool3d_tl)
      input
      ( [ natValI @(Fst3 outputSize),
          natValI @(Snd3 outputSize),
          natValI @(Trd3 outputSize)
        ] ::
          [Int]
      )

-- | averaged 2-D pooling
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = avgPool2d @'(1,1) @'(1,1) @'(0,0) (ones :: CPUTensor 'D.Float '[1,3,4,5])
-- >>> shape t
-- [1,3,4,5]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 4, 5]
avgPool2d ::
  forall
    kernelSize
    stride
    padding
    channelSize
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst kernelSize,
         Torch.Typed.Auxiliary.Snd kernelSize,
         Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         channelSize,
         inputSize0,
         inputSize1,
         batchSize
       ],
    ConvSideCheck inputSize0 (Torch.Typed.Auxiliary.Fst kernelSize) (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 (Torch.Typed.Auxiliary.Snd kernelSize) (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1]
avgPool2d input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.avg_pool2d_tlllbbl
      input
      ([natValI @(Torch.Typed.Auxiliary.Fst kernelSize), natValI @(Torch.Typed.Auxiliary.Snd kernelSize)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst stride), natValI @(Torch.Typed.Auxiliary.Snd stride)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst padding), natValI @(Torch.Typed.Auxiliary.Snd padding)] :: [Int])
      False
      True
      (1 :: Int)
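
-- A note added for illustration: the literal trailing arguments above
-- correspond, in order, to ceil_mode (False), count_include_pad (True), and
-- the divisor_override slot of ATen's avg_pool2d; fixing them is a choice of
-- this wrapper rather than something the type signature exposes, and how the
-- final Int is translated to the optional argument is left to the Castable
-- instance.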

-- | averaged 3-D pooling
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
--
-- >>> t = avgPool3d @'(1,1,1) @'(1,1,1) @'(0,0,0) (ones :: CPUTensor 'D.Float '[1,3,4,5,6])
-- >>> shape t
-- [1,3,4,5,6]
-- >>> :t t
-- t :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 3, 4, 5, 6]
avgPool3d ::
  forall
    kernelSize
    stride
    padding
    channelSize
    inputSize0
    inputSize1
    inputSize2
    batchSize
    outputSize0
    outputSize1
    outputSize2
    dtype
    device.
  ( All
      KnownNat
      '[ Fst3 kernelSize,
         Snd3 kernelSize,
         Trd3 kernelSize,
         Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         channelSize,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 (Fst3 kernelSize) (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 (Snd3 kernelSize) (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 (Trd3 kernelSize) (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1, inputSize2] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1, outputSize2]
avgPool3d input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.avg_pool3d_tlllbbl
      input
      ( [ natValI @(Fst3 kernelSize),
          natValI @(Snd3 kernelSize),
          natValI @(Trd3 kernelSize)
        ] ::
          [Int]
      )
      ([natValI @(Fst3 stride), natValI @(Snd3 stride), natValI @(Trd3 stride)] :: [Int])
      ([natValI @(Fst3 padding), natValI @(Snd3 padding), natValI @(Trd3 padding)] :: [Int])
      False
      True
      (1 :: Int)

-- fractional_max_pool2d :: Tensor device dtype shape -> (Int,Int) -> (Int,Int) -> Tensor device dtype shape -> (Tensor device dtype shape,Tensor device dtype shape)
-- fractional_max_pool2d _input _kernel_size _output_size _random_samples = unsafePerformIO $ (ATen.cast4 ATen.Managed.fractional_max_pool2d_tllt) _input _kernel_size _output_size _random_samples

-- fractional_max_pool3d :: Tensor device dtype shape -> (Int,Int,Int) -> (Int,Int,Int) -> Tensor device dtype shape -> (Tensor device dtype shape,Tensor device dtype shape)
-- fractional_max_pool3d _input _kernel_size _output_size _random_samples = unsafePerformIO $ (ATen.cast4 ATen.Managed.fractional_max_pool3d_tllt) _input _kernel_size _output_size _random_samples

-- max_pool2d_with_indices :: Tensor device dtype shape -> (Int,Int) -> (Int,Int) -> (Int,Int) -> (Int,Int) -> Bool -> (Tensor device dtype shape,Tensor device dtype shape)
-- max_pool2d_with_indices _input _kernel_size _stride _padding _dilation _ceil_mode = unsafePerformIO $ (ATen.cast6 ATen.Managed.max_pool2d_with_indices_tllllb) _input _kernel_size _stride _padding _dilation _ceil_mode

-- max_pool3d_with_indices :: Tensor device dtype shape -> (Int,Int,Int) -> (Int,Int,Int) -> (Int,Int,Int) -> (Int,Int,Int) -> Bool -> (Tensor device dtype shape,Tensor device dtype shape)
-- max_pool3d_with_indices _input _kernel_size _stride _padding _dilation _ceil_mode = unsafePerformIO $ (ATen.cast6 ATen.Managed.max_pool3d_with_indices_tllllb) _input _kernel_size _stride _padding _dilation _ceil_mode

-- max_unpool2d :: Tensor device dtype shape -> Tensor device dtype shape -> (Int,Int) -> Tensor device dtype shape
-- max_unpool2d _input _indices _output_size = unsafePerformIO $ (ATen.cast3 ATen.Managed.max_unpool2d_ttl) _input _indices _output_size

-- max_unpool3d :: Tensor device dtype shape -> Tensor device dtype shape -> (Int,Int,Int) -> (Int,Int,Int) -> (Int,Int,Int) -> Tensor device dtype shape
-- max_unpool3d _input _indices _output_size _stride _padding = unsafePerformIO $ (ATen.cast5 ATen.Managed.max_unpool3d_ttlll) _input _indices _output_size _stride _padding

-- reflection_pad1d :: Tensor device dtype shape -> (Int,Int) -> Tensor device dtype shape
-- reflection_pad1d _input _padding = unsafePerformIO $ (ATen.cast2 ATen.Managed.reflection_pad1d_tl) _input _padding

-- reflection_pad2d :: Tensor device dtype shape -> (Int,Int,Int,Int) -> Tensor device dtype shape
-- reflection_pad2d _input _padding = unsafePerformIO $ (ATen.cast2 ATen.Managed.reflection_pad2d_tl) _input _padding

-- replication_pad1d :: Tensor device dtype shape -> (Int,Int) -> Tensor device dtype shape
-- replication_pad1d _input _padding = unsafePerformIO $ (ATen.cast2 ATen.Managed.replication_pad1d_tl) _input _padding

-- replication_pad2d :: Tensor device dtype shape -> (Int,Int,Int,Int) -> Tensor device dtype shape
-- replication_pad2d _input _padding = unsafePerformIO $ (ATen.cast2 ATen.Managed.replication_pad2d_tl) _input _padding

-- replication_pad3d :: Tensor device dtype shape -> (Int,Int,Int,Int,Int,Int) -> Tensor device dtype shape
-- replication_pad3d _input _padding = unsafePerformIO $ (ATen.cast2 ATen.Managed.replication_pad3d_tl) _input _padding

-- upsample_linear1d :: Tensor device dtype shape -> Int -> Bool -> Tensor device dtype shape
-- upsample_linear1d _input _output_size _align_corners = unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_linear1d_tlb) _input _output_size _align_corners

type family Upsample2dCheck shape h w where
  Upsample2dCheck (b : c : w : h : '[]) h' w' =
    If
      (h <=? h')
      ( If
          (w <=? w')
          (b : c : w' : h' : '[])
          (TypeError (Text "Target width must be greater than current width!"))
      )
      (TypeError (Text "Target height must be greater than current height!"))
  Upsample2dCheck _ _ _ = TypeError (Text "Shape must be 4 dimensional!")

type Upsample2d shape h w = Upsample2dCheck shape h w
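
-- A worked reduction added for illustration: the last two extents of the
-- input shape are matched and replaced by the requested ones, so
-- @Upsample2d '[2, 3, 2, 2] 5 3 ~ '[2, 3, 3, 5]@; the doctests below therefore
-- pass the target extents as @upsample_bilinear2d @3 @5@ to obtain an output
-- of shape [2, 3, 3, 5].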

-- | Applies a 2D bilinear upsampling to an input signal composed of several input channels.
--
-- >>> (dtype &&& shape) $ upsample_bilinear2d @3 @5 False (ones :: CPUTensor 'D.Float '[2,3,2,2])
-- (Float,[2,3,3,5])
upsample_bilinear2d ::
  forall w h shape dtype device.
  (KnownNat h, KnownNat w, All KnownNat shape) =>
  -- | If True, the corner pixels of the input and output tensors are aligned, thus preserving the values at those pixels.
  Bool ->
  Tensor device dtype shape ->
  Tensor device dtype (Upsample2d shape h w)
upsample_bilinear2d _align_corners _input =
  unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bilinear2d_tlb) _input ([w, h] :: [Int]) _align_corners
  where
    w = natValI @w :: Int
    h = natValI @h :: Int

-- | Applies a 2D bicubic upsampling to an input signal composed of several input channels.
--
-- >>> (dtype &&& shape) $ upsample_bicubic2d @3 @5 False (ones :: CPUTensor 'D.Float '[2,3,2,2])
-- (Float,[2,3,3,5])
upsample_bicubic2d ::
  forall w h shape dtype device.
  (KnownNat h, KnownNat w, All KnownNat shape) =>
  Bool ->
  Tensor device dtype shape ->
  Tensor device dtype (Upsample2d shape h w)
upsample_bicubic2d _align_corners _input =
  unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bicubic2d_tlb) _input ([w, h] :: [Int]) _align_corners
  where
    w = natValI @w :: Int
    h = natValI @h :: Int

-- upsample_trilinear3d :: Tensor device dtype shape -> (Int,Int,Int) -> Bool -> Tensor device dtype shape
-- upsample_trilinear3d _input _output_size _align_corners = unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_trilinear3d_tlb) _input _output_size _align_corners

-- upsample_nearest1d :: Tensor device dtype shape -> Int -> Tensor device dtype shape
-- upsample_nearest1d _input _output_size = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest1d_tl) _input _output_size

-- | Applies a 2D nearest-neighbor upsampling to an input signal composed of several input channels.
--
-- >>> (dtype &&& shape) $ upsample_nearest2d @3 @5 (ones :: CPUTensor 'D.Float '[2,3,2,2])
-- (Float,[2,3,3,5])
upsample_nearest2d ::
  forall w h shape dtype device.
  (KnownNat h, KnownNat w, All KnownNat shape) =>
  Tensor device dtype shape ->
  Tensor device dtype (Upsample2d shape h w)
upsample_nearest2d _input =
  unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest2d_tl) _input ([w, h] :: [Int])
  where
    w = natValI @w :: Int
    h = natValI @h :: Int

-- upsample_nearest3d :: Tensor device dtype shape -> (Int,Int,Int) -> Tensor device dtype shape
-- upsample_nearest3d _input _output_size = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest3d_tl) _input _output_size

-- conv_dilated2d :: Tensor device dtype shape -> Tensor device dtype shape -> (Int,Int) -> Tensor device dtype shape -> (Int,Int) -> (Int,Int) -> (Int,Int) -> Tensor device dtype shape
-- conv_dilated2d _input _weight _kernel_size _bias _stride _padding _dilation = unsafePerformIO $ (ATen.cast7 ATen.Managed.conv_dilated2d_ttltlll) _input _weight _kernel_size _bias _stride _padding _dilation

-- conv_dilated3d :: Tensor device dtype shape -> Tensor device dtype shape -> (Int,Int,Int) -> Tensor device dtype shape -> (Int,Int,Int) -> (Int,Int,Int) -> (Int,Int,Int) -> Tensor device dtype shape
-- conv_dilated3d _input _weight _kernel_size _bias _stride _padding _dilation = unsafePerformIO $ (ATen.cast7 ATen.Managed.conv_dilated3d_ttltlll) _input _weight _kernel_size _bias _stride _padding _dilation

-- col2im :: Tensor device dtype shape -> (Int,Int) -> (Int,Int) -> (Int,Int) -> (Int,Int) -> (Int,Int) -> Tensor device dtype shape
-- col2im _input _output_size _kernel_size _dilation _padding _stride = unsafePerformIO $ (ATen.cast6 ATen.Managed.col2im_tlllll) _input _output_size _kernel_size _dilation _padding _stride

-- im2col :: Tensor device dtype shape -> (Int,Int) -> (Int,Int) -> (Int,Int) -> (Int,Int) -> Tensor device dtype shape
-- im2col _input _kernel_size _dilation _padding _stride = unsafePerformIO $ (ATen.cast5 ATen.Managed.im2col_tllll) _input _kernel_size _dilation _padding _stride