{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeApplications #-}

module Torch.Functional
  ( module Torch.Functional,
    Internal.acos,
    Internal.addmv,
    Internal.addr,
    Internal.allclose,
    Internal.argmin,
    Internal.asin,
    Internal.atan,
    Internal.baddbmm,
    Internal.bmm,
    Internal.conj,
    Internal.det,
    Internal.dot,
    Internal.einsum,
    Internal.expm1,
    Internal.ger,
    Internal.logdet,
    Internal.lstsq,
    Internal.mv,
    Internal.sumWithDimnames,
  )
where

import Data.Int
import Foreign.C.Types (CBool (..))
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.DType
import Torch.Dimname
import qualified Torch.Functional.Internal as Internal
import Torch.Internal.Cast
import Torch.Internal.Class
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Scalar as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.Tuple as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Scalar
import Torch.Tensor
-- import Torch.Functional.Internal hiding (argmax, clamp, cosh, conv1d, linear, softmax)
import Torch.TensorFactories (ones', onesLike)
import Prelude hiding
  ( acos,
    acosh,
    all,
    any,
    asin,
    asinh,
    atan,
    atanh,
    ceil,
    cos,
    cosh,
    exp,
    floor,
    isNaN,
    log,
    max,
    min,
    round,
    sin,
    sinh,
    tan,
    tanh,
  )
import qualified Prelude as P

-- | The ATen scalar @1@, built once with 'ATen.newScalar_i'.
--
-- It is passed as the trailing scalar argument of @add_tts@ / @sub_tts@ by
-- 'add' and 'sub'.
--
-- The value is produced with 'unsafePerformIO'; the @NOINLINE@ pragma keeps
-- GHC from inlining the expression at use sites, so the scalar is allocated
-- at most once and shared.
kOne :: ForeignPtr ATen.Scalar
kOne = unsafePerformIO $ ATen.newScalar_i 1
{-# NOINLINE kOne #-}

-- | Arithmetic on tensors via the ATen bindings.
--
-- '+', '-' and '*' delegate to 'add', 'sub' and 'mul'. The unary methods call
-- the IO-typed ATen functions through 'cast1' and expose them as pure
-- functions with 'unsafePerformIO'.
instance Num Tensor where
  (+) = add
  (-) = sub
  (*) = mul
  negate t = unsafePerformIO $ cast1 ATen.neg_t t
  abs t = unsafePerformIO $ cast1 ATen.abs_t t
  signum t = unsafePerformIO $ cast1 ATen.sign_t t
  -- Integer literals go through 'Int' first, so the resulting tensor is built
  -- from an 'Int'. NOTE(review): an 'Integer' outside the 'Int' range wraps
  -- in @fromInteger \@Int@ before the tensor is created.
  fromInteger i = asTensor @Int $ fromInteger @Int i

-- | Two tensors are equal when 'all' holds for the tensor produced by 'eq'.
--
-- 'all' here is this module's @Tensor -> Bool@ function (Prelude's 'all' is
-- hidden in the import list), and 'eq' has type @Tensor -> Tensor -> Tensor@.
-- NOTE(review): both are defined outside the visible part of the file;
-- presumably 'eq' is element-wise equality - confirm.
instance Eq Tensor where
  (==) t t' = all (t `eq` t')

-- | Division-related operations on tensors via the ATen bindings.
--
-- '/' calls @div_tt@ and 'recip' calls @reciprocal_t@; both IO-typed bindings
-- are exposed as pure functions with 'unsafePerformIO'.
instance Fractional Tensor where
  a / b = unsafePerformIO $ cast2 ATen.div_tt a b
  recip t = unsafePerformIO $ cast1 ATen.reciprocal_t t
  -- Rational literals are converted to 'Float' first, so the resulting tensor
  -- is built from a 'Float' (single precision).
  fromRational i = asTensor @Float $ fromRational @Float i

-- | Selects the upper or the lower triangular part of a matrix.
--
-- See 'isUpper' for the conversion to a 'Bool' flag.
data Tri = Upper | Lower deriving (Eq, Show)

-- | Reductions, used by BCE loss, see
-- <https://github.com/pytorch/pytorch/blob/3762cf9cc63e2032410d50f218c1406668177c23/aten/src/ATen/core/Reduction.h>
--
-- The 'Castable' instance in this module maps the constructors to the
-- integers 0 ('ReduceNone'), 1 ('ReduceMean') and 2 ('ReduceSum').
data Reduction = ReduceNone | ReduceMean | ReduceSum deriving (Eq, Show)

-- | A dimension index, wrapped so it cannot be confused with other 'Int'
-- arguments. Used for the dimension parameter of functions such as 'argmax'
-- and 'softmax'.
newtype Dim = Dim Int

-- | Whether a reduction keeps the reduced dimension in its output
-- ('KeepDim') or drops it ('RemoveDim').
--
-- See 'keepdim' for the conversion to a 'Bool' flag.
data KeepDim = KeepDim | RemoveDim deriving (Eq, Show)

-- | Rounding mode flag with two alternatives, 'Ceil' and 'Floor'.
--
-- The 'Castable' instance in this module marshals it as a 'CBool'
-- ('Ceil' is 1, 'Floor' is 0).
-- NOTE(review): presumably selects how pooling output sizes are rounded -
-- confirm against the call sites, which are outside this view.
data CeilMode = Ceil | Floor deriving (Eq, Show)

-- | Marshals 'CeilMode' to and from a C boolean: 'Ceil' is @1@ and 'Floor'
-- is @0@.
instance Castable CeilMode CBool where -- Word8 == CBool
  cast Ceil f = f 1
  cast Floor f = f 0
  uncast 0 f = f Floor
  -- Any non-zero value is treated as true, following C truthiness. The
  -- previous equations matched only 0 and 1, so any other 'CBool' value
  -- failed with a pattern-match error.
  uncast _ f = f Ceil

-- | Marshals 'Reduction' to and from its integer code:
-- 'ReduceNone' is @0@, 'ReduceMean' is @1@ and 'ReduceSum' is @2@.
instance Castable Reduction Int64 where
  cast ReduceNone f = f 0
  cast ReduceMean f = f 1
  cast ReduceSum f = f 2
  uncast 0 f = f ReduceNone
  uncast 1 f = f ReduceMean
  -- Every remaining code (including 2) maps to 'ReduceSum', which keeps
  -- 'uncast' total.
  uncast _ f = f ReduceSum

-- | An 'Int' wrapped in its own type.
-- NOTE(review): presumably a diagonal offset for triangular/diagonal
-- operations - the use sites are outside this view, confirm there.
newtype Diag = Diag Int

-- | Converts a 'Tri' to a boolean flag: 'True' for 'Upper', 'False' for
-- 'Lower'.
isUpper :: Tri -> Bool
isUpper Upper = True
isUpper Lower = False

-- | Returns the mean value of all elements in the input tensor.
mean ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
mean t = unsafePerformIO $ cast1 ATen.mean_t t

-- | Returns the standard deviation of all elements in the input tensor.
std ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
std t = unsafePerformIO $ cast1 ATen.std_t t

-- | Returns the variance of all elements in the input tensor.
var ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
var t = unsafePerformIO $ cast1 ATen.var_t t

-- | Returns the sum of all elements in the input tensor.
sumAll ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
sumAll t = unsafePerformIO $ cast1 ATen.sum_t t

-- | Computes the element-wise absolute value of the given input tensor.
--
-- This top-level function uses the same ATen call (@abs_t@) as the 'abs'
-- method of the @Num Tensor@ instance.
abs ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
abs t = unsafePerformIO $ cast1 ATen.abs_t t

-- | Computes the fractional portion of each element in input.
--
-- @out_i = input_i - (floor . abs) input_i * (sign input_i)@
frac ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
frac t = unsafePerformIO $ cast1 ATen.frac_t t

-- | Converts a 'KeepDim' to the boolean flag passed to ATen: 'True' for
-- 'KeepDim', 'False' for 'RemoveDim'.
keepdim :: KeepDim -> Bool
keepdim KeepDim = True
keepdim RemoveDim = False

-- | Returns the indices of the maximum values of the input tensor along the
-- given dimension.
argmax ::
  -- | the dimension to reduce
  Dim ->
  -- | whether the output tensor has dim retained or not
  KeepDim ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, dim, keepdim), so the arguments are
-- reordered; unsafePerformIO exposes the IO-typed call as a pure function.
argmax (Dim d) k t = unsafePerformIO $ cast3 ATen.argmax_tlb t d (keepdim k)

-- | Each element of the tensor other is added to each element of the tensor
-- input. The resulting tensor is returned.
add ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
-- @add_tts@ takes a trailing scalar argument; the constant 'kOne' (1) is
-- passed for it. unsafePerformIO exposes the IO-typed call as a pure function.
add a b = unsafePerformIO $ cast3 ATen.add_tts a b kOne

-- | Multiplies each element of the input tensor by the corresponding element
-- of the tensor other and returns a new resulting tensor.
mul ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
mul a b = unsafePerformIO $ cast2 ATen.mul_tt a b

-- | Element-wise subtraction of the other tensor from the input tensor;
-- returns a new resulting tensor.
sub ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
-- @sub_tts@ takes a trailing scalar argument; the constant 'kOne' (1) is
-- passed for it. unsafePerformIO exposes the IO-typed call as a pure function.
sub a b = unsafePerformIO $ cast3 ATen.sub_tts a b kOne

-- | Element-wise division of the input tensor by the other tensor; returns a
-- new resulting tensor.
--
-- Uses the same ATen call (@div_tt@) as '/' in the @Fractional Tensor@
-- instance.
div ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
div a b = unsafePerformIO $ cast2 ATen.div_tt a b

-- | Applies ATen's @ceil@ to the input tensor.
-- NOTE(review): presumably rounds every element up to the nearest integer -
-- confirm against the ATen documentation.
ceil ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
ceil t = unsafePerformIO $ cast1 ATen.ceil_t t

-- | Applies ATen's @floor@ to the input tensor.
-- NOTE(review): presumably rounds every element down to the nearest integer -
-- confirm against the ATen documentation.
floor ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
floor t = unsafePerformIO $ cast1 ATen.floor_t t

-- | Applies ATen's single-argument @min@ to the input tensor (no dimension
-- is given).
-- NOTE(review): presumably the minimum over all elements - confirm.
min ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
min t = unsafePerformIO $ cast1 ATen.min_t t

-- | Applies ATen's single-argument @max@ to the input tensor (no dimension
-- is given).
-- NOTE(review): presumably the maximum over all elements - confirm.
max ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
max t = unsafePerformIO $ cast1 ATen.max_t t

-- | Applies ATen's single-argument @median@ to the input tensor (no
-- dimension is given).
-- NOTE(review): presumably the median over all elements - confirm.
median ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
median t = unsafePerformIO $ cast1 ATen.median_t t

-- | Adds the scalar to each element of the input and returns a new resulting
-- tensor.
addScalar ::
  Scalar a =>
  -- | summand
  a ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, scalar), so the arguments are swapped;
-- unsafePerformIO exposes the IO-typed call as a pure function.
addScalar a t = unsafePerformIO $ cast2 ATen.add_ts t a

-- | Subtracts the scalar from each element of the input and returns a new
-- resulting tensor.
subScalar ::
  Scalar a =>
  -- | subtrahend
  a ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, scalar), so the arguments are swapped;
-- unsafePerformIO exposes the IO-typed call as a pure function.
subScalar a t = unsafePerformIO $ cast2 ATen.sub_ts t a

-- | Multiplies each element of the input by the scalar and returns a new
-- resulting tensor.
mulScalar ::
  Scalar a =>
  -- | multiplier
  a ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, scalar), so the arguments are swapped;
-- unsafePerformIO exposes the IO-typed call as a pure function.
mulScalar a t = unsafePerformIO $ cast2 ATen.mul_ts t a

-- | Divides each element of the input by the scalar and returns a new
-- resulting tensor.
divScalar ::
  Scalar a =>
  -- | divisor
  a ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, scalar), so the arguments are swapped;
-- unsafePerformIO exposes the IO-typed call as a pure function.
divScalar a t = unsafePerformIO $ cast2 ATen.div_ts t a

-- | Matrix product of two tensors.
--
-- The behavior depends on the dimensionality of the tensors as follows:
--
-- * If both tensors are 1-dimensional, the dot product (scalar) is returned.
-- * If both arguments are 2-dimensional, the matrix-matrix product is
--   returned.
-- * If the first argument is 1-dimensional and the second argument is
--   2-dimensional, a 1 is prepended to its dimension for the purpose of the
--   matrix multiply. After the matrix multiply, the prepended dimension is
--   removed.
-- * If the first argument is 2-dimensional and the second argument is
--   1-dimensional, the matrix-vector product is returned.
-- * If both arguments are at least 1-dimensional and at least one argument is
--   N-dimensional (where N > 2), then a batched matrix multiply is returned.
--   If the first argument is 1-dimensional, a 1 is prepended to its dimension
--   for the purpose of the batched matrix multiply and removed after. If the
--   second argument is 1-dimensional, a 1 is appended to its dimension for
--   the purpose of the batched matrix multiply and removed after. The
--   non-matrix (i.e. batch) dimensions are broadcasted (and thus must be
--   broadcastable). For example, if input is a (j × 1 × n × m) tensor and
--   other is a (k × m × p) tensor, out will be a (j × k × n × p) tensor.
matmul ::
  -- | first tensor for matrix multiplication
  Tensor ->
  -- | second tensor for matrix multiplication
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
matmul a b = unsafePerformIO $ cast2 ATen.matmul_tt a b

-- | A simple lookup table that looks up embeddings in a fixed dictionary and
-- size.
--
-- This module is often used to retrieve word embeddings using indices. The
-- input to the module is a list of indices and the embedding matrix; the
-- output is the corresponding word embeddings.
embedding ::
  -- | whether or not to scale the gradient by the frequencies
  Bool ->
  -- | whether or not the embedding is sparse
  Bool ->
  -- | weights
  Tensor ->
  -- | padding index
  Int ->
  -- | indices
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (weights, indices, paddingIdx, scaleByGradFreq,
-- sparse), so the arguments are reordered; unsafePerformIO exposes the
-- IO-typed call as a pure function.
embedding scaleByGradFreq sparse weights paddingIdx indices =
  unsafePerformIO $
    cast5
      ATen.embedding_ttlbb
      weights
      indices
      paddingIdx
      scaleByGradFreq
      sparse

-- | Embedding lookup with fixed options: calls the same ATen function as
-- 'embedding' with padding index @-1@ and with both the gradient-scaling and
-- sparse flags set to 'False'.
embedding' ::
  -- | weights
  Tensor ->
  -- | indices
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
embedding' weights indices =
  unsafePerformIO $
    cast5
      ATen.embedding_ttlbb
      weights
      indices
      (-1 :: Int)
      False
      False

-- | A one hot encoding of the given input. The encoding is based on the given
-- number of classes.
oneHot ::
  -- | number of classes
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, numClasses), so the arguments are swapped;
-- unsafePerformIO exposes the IO-typed call as a pure function.
oneHot numClasses t = unsafePerformIO $ cast2 ATen.one_hot_tl t numClasses

--
-- element-wise transformations / non-linearities
--

-- | Computes the error function of each element.
erf ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
erf t = unsafePerformIO $ cast1 ATen.erf_t t

-- | Computes the complementary error function of each element of input.
erfc ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
erfc t = unsafePerformIO $ cast1 ATen.erfc_t t

-- | Computes the inverse error function of each element of input. The
-- inverse error function is defined in the range (-1, 1) as:
-- @erfinv(erf(x)) = x@
erfinv ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
erfinv t = unsafePerformIO $ cast1 ATen.erfinv_t t

-- | Computes the logarithm of the gamma function on input.
lgamma ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
lgamma t = unsafePerformIO $ cast1 ATen.lgamma_t t

-- | Computes the logarithmic derivative of the gamma function on input.
digamma ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
digamma t = unsafePerformIO $ cast1 ATen.digamma_t t

-- | Computes the nth derivative of the digamma function on input. n ≥ 0 is
-- called the order of the polygamma function.
polygamma ::
  -- | n
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- Unlike most bindings here, @polygamma_lt@ takes the integer first, so the
-- argument order is kept; unsafePerformIO exposes the IO-typed call as a pure
-- function.
polygamma n t = unsafePerformIO $ cast2 ATen.polygamma_lt n t

-- | Computes the multivariate log-gamma function with dimension p
-- element-wise. All elements must be greater than (p-1)/2, otherwise an error
-- would be thrown.
mvlgamma ::
  -- | p
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, p), so the arguments are swapped;
-- unsafePerformIO exposes the IO-typed call as a pure function.
mvlgamma p t = unsafePerformIO $ cast2 ATen.mvlgamma_tl t p

-- | Returns a new tensor with the exponential of the elements of the input
-- tensor.
exp ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
exp t = unsafePerformIO $ cast1 ATen.exp_t t

-- | Returns a new tensor with the natural logarithm of (1 + input).
log1p ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
log1p t = unsafePerformIO $ cast1 ATen.log1p_t t

-- | Returns a new tensor with the logarithm to the base 2 of the elements of
-- input.
log2 ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
log2 t = unsafePerformIO $ cast1 ATen.log2_t t

-- | Returns a new tensor with the natural logarithm of the elements of input.
log ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
log t = unsafePerformIO $ cast1 ATen.log_t t

-- | Returns a new tensor with the logarithm to the base 10 of the elements of
-- input.
log10 ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
log10 t = unsafePerformIO $ cast1 ATen.log10_t t

-- | Takes the power of each element in input with the scalar exponent and
-- returns a tensor with the result.
pow ::
  Scalar a =>
  -- | exponent
  a ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, scalar), so the arguments are swapped;
-- unsafePerformIO exposes the IO-typed call as a pure function.
pow s t = unsafePerformIO $ cast2 ATen.pow_ts t s

-- | Takes the power of each element in input with exponent and returns a
-- tensor with the result. Exponent is a tensor with the same number of
-- elements as input.
--
-- Note that, unlike 'pow', the input comes first and the exponent second.
powt ::
  -- | input
  Tensor ->
  -- | exponent
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
powt t t' = unsafePerformIO $ cast2 ATen.pow_tt t t'

-- | Applies the rectified linear unit function element-wise.
relu ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
relu t = unsafePerformIO $ cast1 ATen.relu_t t

-- | Applies the exponential linear unit function element-wise, with alpha
-- input, \(\text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))\)
elu ::
  Scalar s =>
  -- | alpha value for ELU formulation
  s ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, alpha), so the arguments are swapped;
-- unsafePerformIO exposes the IO-typed call as a pure function.
elu a t = unsafePerformIO $ cast2 ATen.elu_ts t a

-- | Applies the exponential linear unit function element-wise with the
-- default alpha value = 1.
elu' ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- Calls the single-argument @elu_t@ binding, so alpha is left to ATen's
-- default; unsafePerformIO exposes the IO-typed call as a pure function.
elu' t = unsafePerformIO $ cast1 ATen.elu_t t

-- | Applies element-wise,
-- \(\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))\),
-- with α=1.6732632423543772848170429916717 and
-- scale=1.0507009873554804934193349852946.
selu ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
selu t = unsafePerformIO $ cast1 ATen.selu_t t

-- | Applies element-wise,
-- \(\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))\).
celu ::
  -- | alpha
  Float ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The ATen binding takes (tensor, alpha), so the arguments are swapped;
-- unsafePerformIO exposes the IO-typed call as a pure function.
celu alpha t = unsafePerformIO $ cast2 ATen.celu_ts t alpha

-- | Applies the element-wise function sigmoid.
sigmoid ::
  -- | input
  Tensor ->
  -- | output
  Tensor
-- unsafePerformIO exposes the IO-typed ATen binding as a pure function.
sigmoid t = unsafePerformIO $ cast1 ATen.sigmoid_t t

-- | Applies a softmax function.
--
-- It is applied to all slices along dim, and will re-scale them so that the
-- elements lie in the range [0, 1] and sum to 1.
softmax ::
  -- | dimension
  Dim ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The input's own dtype is passed as the dtype argument of @softmax_tls@;
-- unsafePerformIO exposes the IO-typed call as a pure function.
softmax (Dim d) input =
  unsafePerformIO $
    cast3
      ATen.softmax_tls
      input
      d
      (dtype input)

-- | Applies a softmax followed by a logarithm.
--
-- While mathematically equivalent to log(softmax(x)), doing these two
-- operations separately is slower and numerically unstable. This function
-- uses an alternative formulation to compute the output and gradient
-- correctly.
logSoftmax ::
  -- | dimension
  Dim ->
  -- | input
  Tensor ->
  -- | output
  Tensor
-- The input's own dtype is passed as the dtype argument of
-- @log_softmax_tls@; unsafePerformIO exposes the IO-typed call as a pure
-- function.
logSoftmax (Dim d) input =
  unsafePerformIO $
    cast3
      ATen.log_softmax_tls
      input
      d
      (dtype input)

-- | Thresholds each element of the input Tensor.
--
-- Pure wrapper around @ATen.threshold_tss@, called as
-- @(input, threshold, value)@; the foreign call is run through
-- 'unsafePerformIO'.
threshold ::
  -- | threshold
  Float ->
  -- | value
  Float ->
  -- | input
  Tensor ->
  -- | output
  Tensor
threshold threshold value self =
  unsafePerformIO $ cast3 ATen.threshold_tss self threshold value

-- | Returns a new tensor with the sine of the elements of input.
--
-- Pure wrapper around @ATen.sin_t@ (run through 'unsafePerformIO').
sin ::
  -- | input
  Tensor ->
  -- | output
  Tensor
sin t = unsafePerformIO $ cast1 ATen.sin_t t

-- | Returns a new tensor with the cosine of the elements of input.
--
-- Pure wrapper around @ATen.cos_t@ (run through 'unsafePerformIO').
cos ::
  -- | input
  Tensor ->
  -- | output
  Tensor
cos t = unsafePerformIO $ cast1 ATen.cos_t t

-- | Returns a new tensor with the tangent of the elements of input.
--
-- Pure wrapper around @ATen.tan_t@ (run through 'unsafePerformIO').
tan ::
  -- | input
  Tensor ->
  -- | output
  Tensor
tan t = unsafePerformIO $ cast1 ATen.tan_t t

-- | Returns a new tensor with the hyperbolic sine of the elements of input.
--
-- Pure wrapper around @ATen.sinh_t@ (run through 'unsafePerformIO').
sinh ::
  -- | input
  Tensor ->
  -- | output
  Tensor
sinh t = unsafePerformIO $ cast1 ATen.sinh_t t

-- | Returns a new tensor with the hyperbolic cosine of the elements of input.
--
-- Pure wrapper around @ATen.cosh_t@ (run through 'unsafePerformIO').
cosh ::
  -- | input
  Tensor ->
  -- | output
  Tensor
cosh t = unsafePerformIO $ cast1 ATen.cosh_t t

-- | Returns a new tensor with the hyperbolic tangent of the elements of input.
--
-- Pure wrapper around @ATen.tanh_t@ (run through 'unsafePerformIO').
tanh ::
  -- | input
  Tensor ->
  -- | output
  Tensor
tanh t = unsafePerformIO $ cast1 ATen.tanh_t t

-- | Returns a new tensor with the square-root of the elements of input.
--
-- Pure wrapper around @ATen.sqrt_t@ (run through 'unsafePerformIO').
sqrt ::
  -- | input
  Tensor ->
  -- | output
  Tensor
sqrt t = unsafePerformIO $ cast1 ATen.sqrt_t t

--
-- infix operators
--

-- | Computes input > other element-wise.
-- The second argument is a tensor whose shape is broadcastable with the first argument.
--
-- Pure wrapper around @ATen.gt_tt@ (run through 'unsafePerformIO').
gt ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
gt a b = unsafePerformIO $ cast2 ATen.gt_tt a b

-- | Infix alias for 'gt'.
(>.) :: Tensor -> Tensor -> Tensor
(>.) = gt

-- | Computes input < other element-wise.
-- The second argument is a tensor whose shape is broadcastable with the first argument.
--
-- Pure wrapper around @ATen.lt_tt@ (run through 'unsafePerformIO').
lt ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
lt a b = unsafePerformIO $ cast2 ATen.lt_tt a b

-- | Infix alias for 'lt'.
(<.) :: Tensor -> Tensor -> Tensor
(<.) = lt

-- | Computes input >= other element-wise.
-- The second argument is a tensor whose shape is broadcastable with the first argument.
--
-- Pure wrapper around @ATen.ge_tt@ (run through 'unsafePerformIO').
ge ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
ge a b = unsafePerformIO $ cast2 ATen.ge_tt a b

-- | Infix alias for 'ge'.
(>=.) :: Tensor -> Tensor -> Tensor
(>=.) = ge

-- | Computes input <= other element-wise.
-- The second argument is a tensor whose shape is broadcastable with the first argument.
--
-- Pure wrapper around @ATen.le_tt@ (run through 'unsafePerformIO').
le ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
le a b = unsafePerformIO $ cast2 ATen.le_tt a b

-- | Infix alias for 'le'.
(<=.) :: Tensor -> Tensor -> Tensor
(<=.) = le

-- | Computes input == other element-wise.
-- The second argument is a tensor whose shape is broadcastable with the first argument.
--
-- Pure wrapper around @ATen.eq_tt@ (run through 'unsafePerformIO').
eq ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
eq a b = unsafePerformIO $ cast2 ATen.eq_tt a b

-- | Infix alias for 'eq'.
(==.) :: Tensor -> Tensor -> Tensor
(==.) = eq

-- | Returns a new tensor with the elements of input at the given indices. The input tensor is treated as if it were viewed as a 1-D tensor. The result takes the same shape as the indices.
--
-- Note the argument order: the index tensor comes first here, while
-- @ATen.take_tt@ is called as @(input, index)@. The foreign call is run
-- through 'unsafePerformIO'.
take ::
  -- | index
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
take _index _self = unsafePerformIO $ cast2 ATen.take_tt _self _index

-- | Returns a new 1-D tensor which indexes the input tensor according to the boolean mask, which is a BoolTensor.
-- The shapes of the mask tensor and the input tensor don't need to match, but they must be broadcastable.
--
-- Note the argument order: the mask comes first here, while
-- @ATen.masked_select_tt@ is called as @(input, mask)@. The foreign call is
-- run through 'unsafePerformIO'.
maskedSelect ::
  -- | mask
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
maskedSelect _mask _self =
  unsafePerformIO $ cast2 ATen.masked_select_tt _self _mask

-- | Returns a single tensor containing the indices of all non-zero elements of input.
--
-- This wraps @ATen.nonzero_t@, which yields one 'Tensor' rather than a tuple
-- of per-dimension index tensors. Per the PyTorch documentation for
-- @torch.nonzero@ (non-tuple form), each row of the result holds the indices
-- of one non-zero element. The foreign call is run through 'unsafePerformIO'.
nonzero ::
  -- | input
  Tensor ->
  -- | output
  Tensor
nonzero _self = unsafePerformIO $ cast1 ATen.nonzero_t _self

-- | Element-wise test of whether @self@ is close to @other@ within the given
-- relative and absolute tolerances.
--
-- Per the PyTorch documentation for @torch.isclose@, closeness means
-- @|self - other| <= atol + rtol * |other|@, and @equal_nan@ controls whether
-- two NaNs compare as close.
--
-- Note the argument order: tolerances and the NaN flag come first here, while
-- @ATen.isclose_ttddb@ is called as @(self, other, rtol, atol, equal_nan)@.
-- The foreign call is run through 'unsafePerformIO'.
isclose ::
  -- | rtol
  Double ->
  -- | atol
  Double ->
  -- | equal_nan
  Bool ->
  -- | self
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
isclose rtol atol equalNan self other =
  unsafePerformIO $ cast5 ATen.isclose_ttddb self other rtol atol equalNan

-- | Tests each element of the input for NaN.
--
-- Pure wrapper around @ATen.isnan_t@ (run through 'unsafePerformIO').
isnan ::
  -- | self
  Tensor ->
  -- | a new tensor with boolean elements representing if each element is NaN or not
  Tensor
isnan t = unsafePerformIO $ cast1 ATen.isnan_t t

-- | Returns the result of @ATen.is_nonzero_t@ on the tensor as a 'Bool'.
--
-- NOTE(review): the underlying PyTorch operation is presumably defined only
-- for single-element tensors and raises otherwise; confirm before relying on
-- it for larger inputs. The foreign call is run through 'unsafePerformIO'.
isNonzero ::
  -- | self
  Tensor ->
  -- | output
  Bool
isNonzero _self = unsafePerformIO $ cast1 ATen.is_nonzero_t _self

-- | Returns the result of @ATen.is_same_size_tt@ on the two tensors, i.e.
-- whether @self@ and @other@ have the same size.
--
-- The foreign call is run through 'unsafePerformIO'.
isSameSize ::
  -- | self
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Bool
isSameSize self other = unsafePerformIO $ cast2 ATen.is_same_size_tt self other

-- | Tests whether the data type of the input tensor is a signed type.
--
-- Pure wrapper around @ATen.is_signed_t@ (run through 'unsafePerformIO').
isSigned ::
  -- | input
  Tensor ->
  -- | True if the data type of input is a signed type
  Bool
isSigned t = unsafePerformIO $ cast1 ATen.is_signed_t t

-- | Computes input /= other element-wise.
-- The second argument is a tensor whose shape is broadcastable with the first argument.
--
-- Pure wrapper around @ATen.ne_tt@ (run through 'unsafePerformIO').
ne ::
  -- | input
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
ne a b = unsafePerformIO $ cast2 ATen.ne_tt a b

-- | Infix alias for 'ne'.
(/=.) :: Tensor -> Tensor -> Tensor
(/=.) = ne

-- | Casting to given 'DType', where 'DType' is an object that represents the data type of a tensor in hasktorch.
--
-- Calls @ATen.tensor_to_sbb@ with both boolean flags fixed to 'False'.
-- NOTE(review): the flags are presumably @non_blocking@ and @copy@ — confirm
-- against the ATen binding. The foreign call is run through 'unsafePerformIO'.
toDType ::
  -- | data type to cast to
  DType ->
  -- | input
  Tensor ->
  -- | output
  Tensor
toDType dtype t = unsafePerformIO $ cast4 ATen.tensor_to_sbb t dtype False False

-- | Squeezes the input without naming a dimension.
--
-- Pure wrapper around @ATen.squeeze_t@ (run through 'unsafePerformIO'); see
-- 'squeezeDim' for the single-dimension variant.
squeezeAll ::
  -- | input
  Tensor ->
  -- | output
  Tensor
squeezeAll t = unsafePerformIO $ cast1 ATen.squeeze_t t

-- | Squeezes the input along the given dimension only.
--
-- Pure wrapper around @ATen.squeeze_tl@ (run through 'unsafePerformIO'); see
-- 'squeezeAll' for the variant without a dimension argument.
squeezeDim ::
  -- | dim
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
squeezeDim dim t = unsafePerformIO $ cast2 ATen.squeeze_tl t dim

--
-- Cumulative operations
--

-- | Returns a tuple (values, indices) where values is the cumulative maximum of elements of input in the dimension dim, and indices is the index location of each maximum value found in the dimension dim.
--
-- Pure wrapper around @ATen.cummax_tl@ (run through 'unsafePerformIO').
cummax ::
  -- | dim
  Int ->
  -- | input
  Tensor ->
  -- | output (values, indices)
  (Tensor, Tensor)
cummax _dim _self = unsafePerformIO $ cast2 ATen.cummax_tl _self _dim

-- | Returns a tuple (values, indices) where values is the cumulative minimum of elements of input in the dimension dim, and indices is the index location of each minimum value found in the dimension dim.
--
-- Pure wrapper around @ATen.cummin_tl@ (run through 'unsafePerformIO').
cummin ::
  -- | dim
  Int ->
  -- | input
  Tensor ->
  -- | output (values, indices)
  (Tensor, Tensor)
cummin _dim _self = unsafePerformIO $ cast2 ATen.cummin_tl _self _dim

-- | Returns the cumulative product of elements of input in the dimension dim.
-- For example, if input is a vector of size N, the result will also be a vector of size N.
--
-- Pure wrapper around @ATen.cumprod_tls@, called as @(input, dim, dtype)@;
-- the foreign call is run through 'unsafePerformIO'.
cumprod ::
  -- | dim
  Int ->
  -- | dtype
  DType ->
  -- | input
  Tensor ->
  -- | output
  Tensor
cumprod _dim _dtype _self =
  unsafePerformIO $ cast3 ATen.cumprod_tls _self _dim _dtype

-- | Returns the cumulative sum of elements of input in the dimension dim.
-- For example, if input is a vector of size N, the result will also be a vector of size N.
--
-- Pure wrapper around @ATen.cumsum_tls@, called as @(input, dim, dtype)@;
-- the foreign call is run through 'unsafePerformIO'.
cumsum ::
  -- | dim
  Int ->
  -- | dtype
  DType ->
  -- | input
  Tensor ->
  -- | output
  Tensor
cumsum _dim _dtype _self =
  unsafePerformIO $ cast3 ATen.cumsum_tls _self _dim _dtype

--
-- Loss Functions
--

-- | Function that measures the Binary Cross Entropy between the target and the output.
--
-- Pure wrapper around @ATen.binary_cross_entropy_tttl@, called as
-- @(input, target, weight, reduction)@; the foreign call is run through
-- 'unsafePerformIO'.
binaryCrossEntropyLoss ::
  -- | Specifies the reduction to apply to the output
  Reduction ->
  -- | target
  Tensor ->
  -- | weight
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
binaryCrossEntropyLoss reduction target weight t =
  unsafePerformIO $
    cast4 ATen.binary_cross_entropy_tttl t target weight reduction

-- | Binary Cross Entropy with weights defaulted to 1.0 & reduction defaulted to ReduceMean.
--
-- The weight tensor is built with 'onesLike' from the target. The foreign
-- call to @ATen.binary_cross_entropy_tttl@ is run through 'unsafePerformIO'.
binaryCrossEntropyLoss' ::
  -- | target
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
binaryCrossEntropyLoss' target t =
  unsafePerformIO $
    cast4 ATen.binary_cross_entropy_tttl t target (onesLike target) ReduceMean

-- | This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.
--
-- Pure wrapper around @ATen.binary_cross_entropy_with_logits_ttttl@, called
-- as @(input, target, weight, pos_weight, reduction)@; the foreign call is
-- run through 'unsafePerformIO'.
binaryCrossEntropyWithLogits ::
  -- | Specifies the reduction to apply to the output
  Reduction ->
  -- | target
  Tensor ->
  -- | weight
  Tensor ->
  -- | pos_weight
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
binaryCrossEntropyWithLogits reduction target weight pos_weight input =
  unsafePerformIO $
    cast5
      ATen.binary_cross_entropy_with_logits_ttttl
      input
      target
      weight
      pos_weight
      reduction

-- | Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the @input@ and @target@.
--
-- The reduction is fixed to @ATen.kMean@. The foreign call to
-- @ATen.mse_loss_ttl@ is run through 'unsafePerformIO'.
mseLoss ::
  -- | target tensor
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
mseLoss target t = unsafePerformIO $ cast3 ATen.mse_loss_ttl t target ATen.kMean

-- | The negative log likelihood loss, with defaulted options.
--
-- Calls @ATen.nll_loss_tttll@ with 'ReduceMean', a final integer argument of
-- -100 (presumably the ignore index, as in 'nllLoss2D'), and a uniform weight
-- tensor. The weight is all ones with one entry per class, where the class
-- count is the size of dimension 1 of the input; it is converted to the
-- input's dtype and moved to the target's device.
--
-- Raises a descriptive error if the input has fewer than two dimensions.
-- The foreign call is run through 'unsafePerformIO'.
nllLoss' ::
  -- | target tensor
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
nllLoss' target t =
  unsafePerformIO $
    cast5 ATen.nll_loss_tttll t target weight ReduceMean (-100 :: Int)
  where
    -- Class count = size of dimension 1. Pattern match instead of @!! 1@ so a
    -- non-conforming input fails with a readable message rather than an opaque
    -- list-index error.
    nClass = case shape t of
      (_ : c : _) -> c
      dims ->
        error $
          "nllLoss': input must have at least 2 dimensions, but its shape is "
            ++ show dims
    weight = toDType (dtype t) $ _toDevice (device target) $ ones' [nClass]

-- | Returns cosine similarity between x1 and x2, computed along dim.
--
-- Pure wrapper around @ATen.cosine_similarity_ttld@, called as
-- @(x1, x2, dim, eps)@; the foreign call is run through 'unsafePerformIO'.
cosineSimilarity ::
  -- | dimension of vectors (default=1)
  Dim ->
  -- | small value to avoid division by 0 (default=1e-8)
  Double ->
  -- | x1
  Tensor ->
  -- | x2
  Tensor ->
  -- | output
  Tensor
cosineSimilarity (Dim dim) eps x1 x2 =
  unsafePerformIO $ cast4 ATen.cosine_similarity_ttld x1 x2 dim eps

-- | Returns cosine similarity with defaulted options: dimension 1 and an
-- epsilon of 1e-8.
--
-- The foreign call to @ATen.cosine_similarity_ttld@ is run through
-- 'unsafePerformIO'.
cosineSimilarity' ::
  -- | x1
  Tensor ->
  -- | x2
  Tensor ->
  -- | output
  Tensor
cosineSimilarity' x1 x2 =
  unsafePerformIO $
    cast4 ATen.cosine_similarity_ttld x1 x2 (1 :: Int) (1e-8 :: Double)

-- | The Connectionist Temporal Classification loss.
-- Calculates loss between a continuous (unsegmented) time series and a target sequence.
-- CTCLoss sums over the probability of possible alignments of input to target,
-- producing a loss value which is differentiable with respect to each input node.
-- The alignment of input to target is assumed to be "many-to-one", which limits
-- the length of the target sequence such that it must be ≤ the input length.
--
-- Pure wrapper around @ATen.ctc_loss_ttllllb@, called as
-- @(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)@;
-- the foreign call is run through 'unsafePerformIO'.
ctcLoss ::
  -- | zero_infinity - Whether to zero infinite losses and the associated gradients (False by default). Infinite losses mainly occur when the inputs are too short to be aligned to the targets.
  Bool ->
  -- | blank label
  Int ->
  -- | reduction
  Reduction ->
  -- | input_lengths
  [Int] ->
  -- | target_lengths
  [Int] ->
  -- | log_probs
  Tensor ->
  -- | targets
  Tensor ->
  -- | output
  Tensor
ctcLoss zeroInfinity blank reduction inputLengths targetLengths logProbs targets =
  unsafePerformIO $
    cast7
      ATen.ctc_loss_ttllllb
      logProbs
      targets
      inputLengths
      targetLengths
      blank
      reduction
      zeroInfinity

-- | Returns CTC loss with defaulted options: blank label 0 and
-- zero_infinity set to 'False'.
--
-- The foreign call to @ATen.ctc_loss_ttllllb@ is run through
-- 'unsafePerformIO'.
ctcLoss' ::
  -- | reduction
  Reduction ->
  -- | input lengths
  [Int] ->
  -- | target lengths
  [Int] ->
  -- | log probs
  Tensor ->
  -- | targets
  Tensor ->
  -- | output
  Tensor
ctcLoss' reduction inputLengths targetLengths logProbs targets =
  unsafePerformIO $
    cast7
      ATen.ctc_loss_ttllllb
      logProbs
      targets
      inputLengths
      targetLengths
      blank
      reduction
      zeroInfinity
  where
    -- Defaults applied by this variant.
    blank = 0 :: Int
    zeroInfinity = False

-- | Returns the p-norm of (input - other).
-- The shapes of input and other must be broadcastable.
--
-- Pure wrapper around @ATen.dist_tts@, called as @(input, other, p)@; the
-- foreign call is run through 'unsafePerformIO'.
dist ::
  -- | p
  Float ->
  -- | other
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
dist _p _other _self = unsafePerformIO $ cast3 ATen.dist_tts _self _other _p

-- | Measures the loss given an input tensor x and a labels tensor y (containing 1 or -1).
-- This is usually used for measuring whether two inputs are similar or dissimilar,
-- e.g. using the L1 pairwise distance as x,
-- and is typically used for learning nonlinear embeddings or semi-supervised learning.
--
-- Pure wrapper around @ATen.hinge_embedding_loss_ttdl@, called as
-- @(input, target, margin, reduction)@; the foreign call is run through
-- 'unsafePerformIO'.
hingeEmbeddingLoss ::
  -- | margin
  Double ->
  -- | reduction
  Reduction ->
  -- | target
  Tensor ->
  -- | self
  Tensor ->
  -- | output
  Tensor
hingeEmbeddingLoss margin reduction target t =
  unsafePerformIO $
    cast4 ATen.hinge_embedding_loss_ttdl t target margin reduction

-- | Margin ranking loss between two inputs and a target tensor.
--
-- Pure wrapper around @ATen.margin_ranking_loss_tttdl@, called with the
-- arguments in the same order as they are given here; the foreign call is run
-- through 'unsafePerformIO'.
-- NOTE(review): per the PyTorch @MarginRankingLoss@ documentation the target
-- presumably contains 1 or -1 — confirm.
marginRankingLoss ::
  -- | input1
  Tensor ->
  -- | input2
  Tensor ->
  -- | target
  Tensor ->
  -- | margin
  Double ->
  -- | reduction
  Reduction ->
  -- | output
  Tensor
marginRankingLoss input1 input2 target margin reduction =
  unsafePerformIO $
    cast5 ATen.margin_ranking_loss_tttdl input1 input2 target margin reduction

-- | The 2D negative log likelihood loss.
--
-- Pure wrapper around @ATen.nll_loss2d_tttll@, called as
-- @(input, target, weight, reduction, ignore_index)@; the foreign call is run
-- through 'unsafePerformIO'.
nllLoss2D ::
  -- | reduction
  Reduction ->
  -- | ignore_index
  Int ->
  -- | input
  Tensor ->
  -- | target
  Tensor ->
  -- | weight
  Tensor ->
  -- | output
  Tensor
nllLoss2D reduction ignoreindex input target weight =
  unsafePerformIO $
    cast5 ATen.nll_loss2d_tttll input target weight reduction ignoreindex

-- | Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between input \(x\) (a 2D mini-batch Tensor) and output \(y\) (which is a 1D tensor of target class indices)
--
-- Pure wrapper around @ATen.multi_margin_loss_ttsstl@, called as
-- @(input, target, p, margin, weight, reduction)@; the foreign call is run
-- through 'unsafePerformIO'.
multiMarginLoss ::
  -- | reduction
  Reduction ->
  -- | p
  Float ->
  -- | margin
  Float ->
  -- | input
  Tensor ->
  -- | target
  Tensor ->
  -- | weight
  Tensor ->
  -- | output
  Tensor
multiMarginLoss reduction p margin input target weight =
  unsafePerformIO $
    cast6 ATen.multi_margin_loss_ttsstl input target p margin weight reduction

-- | Multi-label margin loss between input \(x\) and target \(y\).
--
-- This wraps @ATen.multilabel_margin_loss_ttl@. Per the PyTorch
-- @MultiLabelMarginLoss@ documentation this is a multi-class
-- multi-classification hinge (margin-based) loss, not the max-entropy
-- soft-margin variant. The foreign call is run through 'unsafePerformIO'.
multiLabelMarginLoss ::
  -- | reduction
  Reduction ->
  -- | input
  Tensor ->
  -- | target
  Tensor ->
  -- | output
  Tensor
multiLabelMarginLoss reduction input target =
  unsafePerformIO $ cast3 ATen.multilabel_margin_loss_ttl input target reduction

-- | The Kullback-Leibler divergence Loss
-- KL divergence is a useful distance measure for continuous distributions and is often useful when performing direct regression over the space of (discretely sampled) continuous output distributions.
-- As with NLLLoss, the input given is expected to contain log-probabilities and is not restricted to a 2D Tensor.
-- This criterion expects a target Tensor of the same size as the input Tensor.
--
-- This wrapper passes only @(input, target, reduction)@ to @ATen.kl_div_ttl@;
-- it does not expose a @log_target@ option. The foreign call is run through
-- 'unsafePerformIO'.
klDiv ::
  -- | reduction
  Reduction ->
  -- | self
  Tensor ->
  -- | target
  Tensor ->
  -- | output
  Tensor
klDiv reduction self target =
  unsafePerformIO $ cast3 ATen.kl_div_ttl self target reduction

-- | Creates a criterion that uses a squared term if the absolute element-wise
-- error falls below 1 and an L1 term otherwise. It is less sensitive to
-- outliers than the MSELoss and in some cases prevents exploding gradients
-- (e.g. see Fast R-CNN paper by Ross Girshick). Also known as the Huber loss.
--
-- Pure wrapper around @ATen.smooth_l1_loss_ttl@, called as
-- @(input, target, reduction)@; the foreign call is run through
-- 'unsafePerformIO'.
smoothL1Loss ::
  -- | reduction
  Reduction ->
  -- | self
  Tensor ->
  -- | target
  Tensor ->
  -- | output
  Tensor
smoothL1Loss reduction self target =
  unsafePerformIO $ cast3 ATen.smooth_l1_loss_ttl self target reduction

-- | Creates a criterion that optimizes a two-class classification logistic loss
-- between input tensor \(x\) and target tensor \(y\) (containing 1 or -1).
softMarginLoss ::
  -- | reduction
  Reduction ->
  -- | input
  Tensor ->
  -- | target
  Tensor ->
  -- | output
  Tensor
softMarginLoss reduction input target =
  -- The ATen call takes the reduction last; it is wrapped in
  -- 'unsafePerformIO' to present a pure function.
  unsafePerformIO $ cast3 ATen.soft_margin_loss_ttl input target reduction

--
-- Pooling
--

-- | Applies a 1D adaptive max pooling over an input signal composed of several input planes.
adaptiveMaxPool1d ::
  -- | output size
  Int ->
  -- | input
  Tensor ->
  -- | output pair as returned by ATen (presumably pooled values and indices,
  -- as in PyTorch; confirm)
  (Tensor, Tensor)
adaptiveMaxPool1d outputSize self =
  -- ATen expects the tensor first and the size second.
  unsafePerformIO $
    cast2 ATen.adaptive_max_pool1d_tl self outputSize

-- | Applies a 2D adaptive max pooling over an input signal composed of several input planes.
adaptiveMaxPool2d ::
  -- | output size
  (Int, Int) ->
  -- | input
  Tensor ->
  -- | output pair as returned by ATen (presumably pooled values and indices,
  -- as in PyTorch; confirm)
  (Tensor, Tensor)
adaptiveMaxPool2d outputSize self =
  -- ATen expects the tensor first and the size second.
  unsafePerformIO $
    cast2 ATen.adaptive_max_pool2d_tl self outputSize

-- | Applies a 3D adaptive max pooling over an input signal composed of several input planes.
--
-- NOTE(review): the output size is a pair here, whereas 'adaptiveAvgPool3d'
-- takes a triple. A 3D pooling op presumably needs three extents; confirm
-- against ATen before relying on this function. The signature is left
-- unchanged to keep the public interface stable.
adaptiveMaxPool3d ::
  -- | output size
  (Int, Int) ->
  -- | input
  Tensor ->
  -- | output pair as returned by ATen (presumably pooled values and indices,
  -- as in PyTorch; confirm)
  (Tensor, Tensor)
adaptiveMaxPool3d outputSize input =
  -- ATen expects the tensor first and the size second.
  unsafePerformIO $
    cast2 ATen.adaptive_max_pool3d_tl input outputSize

-- | 1D max pooling that returns the ATen result pair, documented here as
-- the pooled output together with the indices.
maxPool1dWithIndices ::
  -- | kernel size
  Int ->
  -- | stride
  Int ->
  -- | padding
  Int ->
  -- | dilation
  Int ->
  -- | ceil mode
  CeilMode ->
  -- | input
  Tensor ->
  -- | output, indices
  (Tensor, Tensor)
maxPool1dWithIndices kernelSize stride padding dilation ceilMode self =
  -- ATen takes the tensor first, followed by the pooling parameters in
  -- declaration order.
  unsafePerformIO $
    cast6
      ATen.max_pool1d_with_indices_tllllb
      self
      kernelSize
      stride
      padding
      dilation
      ceilMode

-- | Applies a 1D max pooling over an input signal composed of several input planes.
maxPool1d ::
  -- | kernel size
  Int ->
  -- | stride
  Int ->
  -- | padding
  Int ->
  -- | dilation
  Int ->
  -- | ceil mode
  CeilMode ->
  -- | input
  Tensor ->
  -- | output
  Tensor
maxPool1d kernelSize stride padding dilation ceilMode self =
  -- ATen takes the tensor first, followed by the pooling parameters in
  -- declaration order.
  unsafePerformIO $
    cast6
      ATen.max_pool1d_tllllb
      self
      kernelSize
      stride
      padding
      dilation
      ceilMode

-- | Applies a 2D max pooling over an input signal composed of several input planes.
maxPool2d ::
  -- | kernel size
  (Int, Int) ->
  -- | stride
  (Int, Int) ->
  -- | padding
  (Int, Int) ->
  -- | dilation
  (Int, Int) ->
  -- | ceil mode
  CeilMode ->
  -- | input
  Tensor ->
  -- | output
  Tensor
maxPool2d kernelSize stride padding dilation ceilMode self =
  unsafePerformIO $
    cast6
      ATen.max_pool2d_tllllb
      self
      (asList kernelSize)
      (asList stride)
      (asList padding)
      (asList dilation)
      ceilMode
  where
    -- Each pair is handed to ATen as a two-element list.
    asList :: (Int, Int) -> [Int]
    asList (a0, a1) = [a0, a1]

-- | Applies a 3D max pooling over an input signal composed of several input planes.
maxPool3d ::
  -- | kernel size
  (Int, Int, Int) ->
  -- | stride
  (Int, Int, Int) ->
  -- | padding
  (Int, Int, Int) ->
  -- | dilation
  (Int, Int, Int) ->
  -- | ceil mode
  CeilMode ->
  -- | input
  Tensor ->
  -- | output
  Tensor
maxPool3d kernelSize stride padding dilation ceilMode self =
  -- The triples are passed straight to the cast machinery; ATen takes the
  -- tensor first, followed by the pooling parameters in declaration order.
  unsafePerformIO $
    cast6
      ATen.max_pool3d_tllllb
      self
      kernelSize
      stride
      padding
      dilation
      ceilMode

-- | Calculates resulting dimensions from a 2d maxpool operation
-- see https://pytorch.org/docs/master/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
--
-- For each axis the result is
-- @round ((img + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)@,
-- where @round@ is 'P.ceiling' for 'Ceil' and 'P.floor' for 'Floor'.
-- The arithmetic is carried out in 'Float'.
maxPool2dDim ::
  -- | kernel size
  (Int, Int) ->
  -- | stride
  (Int, Int) ->
  -- | padding
  (Int, Int) ->
  -- | dilation
  (Int, Int) ->
  -- | Ceiling or Floor
  CeilMode ->
  -- | image dimensions
  (Int, Int) ->
  -- | height, width after maxPool
  (Int, Int)
maxPool2dDim kernelSize stride padding dilation ceilMode imgDim =
  (outDim fst, outDim snd)
  where
    -- Rounding function selected by the pooling ceil mode.
    roundWith Ceil = P.ceiling
    roundWith Floor = P.floor
    -- Output extent along one axis; @axis@ projects that axis out of every
    -- parameter pair ('fst' for the first component, 'snd' for the second).
    outDim axis =
      let f = (fromIntegral . axis :: (Int, Int) -> Float)
       in roundWith ceilMode $
            ( f imgDim
                + 2 * f padding
                - f dilation * (f kernelSize - 1)
                - 1
            )
              / f stride
              + 1

-- | Applies a 1D average pooling over an input signal composed of several input planes.
avgPool1d ::
  -- | kernel size
  Int ->
  -- | stride
  Int ->
  -- | padding
  Int ->
  -- | ceil mode
  CeilMode ->
  -- | count include pad
  Bool ->
  -- | input
  Tensor ->
  -- | output
  Tensor
avgPool1d kernelSize stride padding ceilMode countIncludePad input =
  -- ATen takes the tensor first, followed by the pooling parameters in
  -- declaration order.
  unsafePerformIO $
    cast6
      ATen.avg_pool1d_tlllbb
      input
      kernelSize
      stride
      padding
      ceilMode
      countIncludePad

-- | 'avgPool1d' with its two mode arguments fixed: ceil mode 'Floor' and
-- count-include-pad 'True'.
avgPool1d' ::
  -- | kernel size
  Int ->
  -- | stride
  Int ->
  -- | padding
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
avgPool1d' kernelSize stride padding =
  avgPool1d kernelSize stride padding Floor True

-- | Applies a 1D adaptive average pooling over an input signal composed of several input planes.
adaptiveAvgPool1d ::
  -- | output size
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
adaptiveAvgPool1d outputSize input =
  -- ATen expects the tensor first and the size second.
  unsafePerformIO $
    cast2 ATen.adaptive_avg_pool1d_tl input outputSize

-- | Applies a 2D adaptive average pooling over an input signal composed of several input planes.
adaptiveAvgPool2d ::
  -- | output size (Height * Width)
  (Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
adaptiveAvgPool2d (outputHeight, outputWidth) input =
  unsafePerformIO $
    cast2
      ATen.adaptive_avg_pool2d_tl
      input
      -- The size pair is handed to ATen as a two-element list.
      ([outputHeight, outputWidth] :: [Int])

-- | Applies a 3D adaptive average pooling over an input signal composed of several input planes.
adaptiveAvgPool3d ::
  -- | output size (Depth * Height * Width)
  (Int, Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
adaptiveAvgPool3d outputSize input =
  -- ATen expects the tensor first and the size second.
  unsafePerformIO $
    cast2 ATen.adaptive_avg_pool3d_tl input outputSize

--
-- matrix solvers
--

-- | Takes the inverse of the square matrix input. @input@ can be batches of
-- 2D square tensors, in which case this function would return a tensor
-- composed of individual inverses.
inverse ::
  -- | input
  Tensor ->
  -- | output
  Tensor
inverse t =
  -- Wrapped in 'unsafePerformIO' to present a pure function.
  unsafePerformIO $ cast1 ATen.inverse_t t

-- | Solves a system of equations with a triangular coefficient matrix @A@
-- and multiple right-hand sides @b@.
triangularSolve ::
  -- | A
  Tensor ->
  -- | upper
  Bool ->
  -- | transpose
  Bool ->
  -- | unitriangular
  Bool ->
  -- | input
  Tensor ->
  -- | output
  (Tensor, Tensor)
triangularSolve coeff upper transposeA unitriangular input =
  -- ATen takes the right-hand side (@input@) first, then @A@ and the flags.
  unsafePerformIO $
    cast5
      ATen.triangular_solve_ttbbb
      input
      coeff
      upper
      transposeA
      unitriangular

-- | This function returns eigenvalues and eigenvectors of a real symmetric
-- matrix input or a batch of real symmetric matrices, represented by a
-- namedtuple (eigenvalues, eigenvectors).
symeig ::
  -- | bool which controls whether eigenvectors have to be computed
  Bool ->
  -- | controls whether to consider upper-triangular or lower-triangular region
  Tri ->
  -- | input tensor
  Tensor ->
  -- | output tensors
  (Tensor, Tensor)
symeig eigenvectors upper t =
  -- ATen wants a plain flag, so the 'Tri' selector is converted with 'isUpper'.
  unsafePerformIO $
    cast3 ATen.symeig_tbb t eigenvectors (isUpper upper)

-- | Computes the eigenvalues and eigenvectors of a real square matrix
eig ::
  -- | bool to compute both eigenvalues and eigenvectors; otherwise, only eigenvalues will be computed
  Bool ->
  -- | input (square matrix) for which the eigen values and eigen vectors are to be computed
  Tensor ->
  -- | output tensors
  (Tensor, Tensor)
eig eigenvectors t =
  -- ATen expects the tensor first and the flag second.
  unsafePerformIO $ cast2 ATen.eig_tb t eigenvectors

-- | This function returns a namedtuple (U, S, V) which is the singular value
-- decomposition of a input real matrix or batches of real matrices input such
-- that input = U * diag(S) * V^T
svd ::
  -- | controls the shape of returned U and V
  Bool ->
  -- | option whether to compute U and V or not
  Bool ->
  -- | input
  Tensor ->
  -- | output tuple of tensors
  (Tensor, Tensor, Tensor)
svd some computeUV t =
  -- ATen expects the tensor first, then the two flags in declaration order.
  unsafePerformIO $ cast3 ATen.svd_tbb t some computeUV

-- | Computes the Cholesky decomposition of a symmetric positive-definite
-- matrix @A@ or for batches of symmetric positive-definite matrices.
cholesky ::
  -- | flag that indicates whether to return a upper or lower triangular matrix.
  Tri ->
  -- | input
  Tensor ->
  -- | output
  Tensor
cholesky upper t =
  -- ATen wants a plain flag, so the 'Tri' selector is converted with 'isUpper'.
  unsafePerformIO $ cast2 ATen.cholesky_tb t (isUpper upper)

-- | Solves a linear system of equations with a positive semidefinite matrix
-- to be inverted given its Cholesky factor matrix @u@.
choleskySolve ::
  -- | bool whether to consider the Cholesky factor as a lower or upper triangular matrix
  Tri ->
  -- | input matrix b
  Tensor ->
  -- | input matrix u
  Tensor ->
  -- | output
  Tensor
choleskySolve upper t1 t2 =
  -- ATen wants a plain flag, so the 'Tri' selector is converted with 'isUpper'.
  unsafePerformIO $
    cast3 ATen.cholesky_solve_ttb t1 t2 (isUpper upper)

-- | This function returns the solution to the system of linear equations
-- represented by @AX = B@ and the LU factorization of A, in order as a
-- namedtuple solution, LU.
-- LU contains L and U factors for LU factorization of A
solve ::
  -- | input matrix
  Tensor ->
  -- | input square matrix
  Tensor ->
  -- | output tuple with solution and LU
  (Tensor, Tensor)
solve b a =
  -- Arguments are forwarded to ATen in the same order they are received.
  unsafePerformIO $ cast2 ATen.solve_tt b a

-- | Forwards to ATen's @cholesky_inverse@: presumably computes the inverse of
-- a symmetric positive-definite matrix from its Cholesky factor, as
-- @torch.cholesky_inverse@ does (confirm against the PyTorch documentation).
choleskyInverse ::
  -- | upper or lower triangle
  Tri ->
  -- | input
  Tensor ->
  -- | solution
  Tensor
choleskyInverse upper t =
  -- ATen wants a plain flag, so the 'Tri' selector is converted with 'isUpper'.
  unsafePerformIO $ cast2 ATen.cholesky_inverse_tb t (isUpper upper)

-- pstrf :: Bool -> Double -> Tensor -> (Tensor, Tensor)
-- pstrf upper tol t = unsafePerformIO $ cast3 ATen.pstrf_tbs t upper tol

-- qr :: Tensor -> (Tensor, Tensor)
-- qr t = unsafePerformIO $ cast1 ATen.qr_t t

-- | This is a low-level function for calling LAPACK directly. This function
-- returns a namedtuple (a, tau) as defined in LAPACK documentation for geqrf.
geqrf ::
  -- | input
  Tensor ->
  -- | a, tau output matrices (see https://software.intel.com/en-us/node/521004)
  (Tensor, Tensor)
geqrf t =
  -- Wrapped in 'unsafePerformIO' to present a pure function.
  unsafePerformIO $ cast1 ATen.geqrf_t t

-- | Computes the orthogonal matrix Q of a QR factorization, from the @(input, input2)@ tuple returned by 'geqrf' function.
-- This directly calls the underlying LAPACK function @?orgqr@. See LAPACK documentation for @orgqr@ for further details.
orgqr ::
  -- | the @a@ from @geqrf@ function
  Tensor ->
  -- | the @tau@ from @geqrf@ function
  Tensor ->
  -- | output
  Tensor
orgqr a tau =
  -- Arguments are forwarded to ATen in the same order they are received.
  unsafePerformIO $ cast2 ATen.orgqr_tt a tau

-- | Multiplies mat (given by input3) by the orthogonal Q matrix of the QR factorization formed by torch.geqrf() that is represented by (a, tau) (given by (input, input2)).
-- This directly calls the underlying LAPACK function ?ormqr. See LAPACK documentation for ormqr for further details.
ormqr ::
  -- | input2
  Tensor ->
  -- | input3
  Tensor ->
  -- | left
  Bool ->
  -- | transpose
  Bool ->
  -- | input
  Tensor ->
  -- | output
  Tensor
ormqr input2 input3 left transposeQ input =
  -- ATen takes @input@ first, then the remaining arguments in declaration order.
  unsafePerformIO $
    cast5
      ATen.ormqr_tttbb
      input
      input2
      input3
      left
      transposeQ

-- | Returns the LU solve of the linear system @Ax = b@ using the partially
-- pivoted LU factorization of A from torch.lu().
luSolve ::
  -- | LU_data
  Tensor ->
  -- | LU_pivots
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
luSolve luData luPivots input =
  -- ATen takes @input@ first, then the LU data and the pivots.
  unsafePerformIO $
    cast3 ATen.lu_solve_ttt input luData luPivots

--
-- dropout
--

-- | During training, randomly zeroes some of the elements of the input tensor
-- with probability p using samples from a Bernoulli distribution.
--
-- The result stays in 'IO'; unlike most functions in this module it is not
-- wrapped in @unsafePerformIO@.
dropout ::
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor ->
  -- | output
  IO Tensor
dropout p train input =
  -- ATen expects the tensor first, then the probability and the flag.
  cast3 ATen.dropout_tdb input p train

-- | Forwards to ATen's @feature_dropout@ (presumably the variant of 'dropout'
-- that drops whole feature maps, as in PyTorch; confirm). The result stays
-- in 'IO'.
featureDropout ::
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor ->
  -- | output
  IO Tensor
featureDropout p train input =
  -- ATen expects the tensor first, then the probability and the flag.
  cast3 ATen.feature_dropout_tdb input p train

-- | Applies alpha dropout to the input. The result stays in 'IO'.
alphaDropout ::
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor ->
  -- | output
  IO Tensor
alphaDropout p train input =
  -- ATen expects the tensor first, then the probability and the flag.
  cast3 ATen.alpha_dropout_tdb input p train

-- | Forwards to ATen's @feature_alpha_dropout@ (presumably the feature-map
-- variant of 'alphaDropout', as in PyTorch; confirm). The result stays in
-- 'IO'.
featureAlphaDropout ::
  -- | dropout probability
  Double ->
  -- | whether or not to activate dropout
  Bool ->
  -- | input
  Tensor ->
  -- | output
  IO Tensor
featureAlphaDropout p train input =
  -- ATen expects the tensor first, then the probability and the flag.
  cast3 ATen.feature_alpha_dropout_tdb input p train

--
-- Element-wise logical operators
--

-- | Computes the bitwise NOT of the given input tensor. The input tensor must
-- be of integral or Boolean types. For bool tensors, it computes the logical NOT.
bitwiseNot ::
  -- | input
  Tensor ->
  -- | output
  Tensor
bitwiseNot input =
  -- Wrapped in 'unsafePerformIO' to present a pure function.
  unsafePerformIO $ cast1 ATen.bitwise_not_t input

-- | Computes the element-wise logical NOT of the given input tensor. If not
-- specified, the output tensor will have the bool dtype. If the input tensor
-- is not a bool tensor, zeros are treated as False and non-zeros are treated
-- as True.
logicalNot ::
  -- | input
  Tensor ->
  -- | output
  Tensor
logicalNot t =
  -- Wrapped in 'unsafePerformIO' to present a pure function.
  unsafePerformIO $ cast1 ATen.logical_not_t t

-- | Forwards both tensors to ATen's @logical_xor@ (presumably an
-- element-wise logical XOR, by analogy with 'logicalNot'; confirm).
logicalXor ::
  -- | self
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
logicalXor self other =
  -- Wrapped in 'unsafePerformIO' to present a pure function.
  unsafePerformIO $ cast2 ATen.logical_xor_tt self other

-- | Forwards both tensors to ATen's @logical_and@ (presumably an
-- element-wise logical AND, by analogy with 'logicalNot'; confirm).
logicalAnd ::
  -- | self
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
logicalAnd self other =
  -- Wrapped in 'unsafePerformIO' to present a pure function.
  unsafePerformIO $ cast2 ATen.logical_and_tt self other

-- | Forwards both tensors to ATen's @logical_or@ (presumably an
-- element-wise logical OR, by analogy with 'logicalNot'; confirm).
logicalOr ::
  -- | self
  Tensor ->
  -- | other
  Tensor ->
  -- | output
  Tensor
logicalOr self other =
  -- Wrapped in 'unsafePerformIO' to present a pure function.
  unsafePerformIO $ cast2 ATen.logical_or_tt self other

-- | Concatenates the given sequence of seq tensors in the given dimension.
-- All tensors must either have the same shape (except in the concatenating
-- dimension) or be empty.
cat ::
  -- | dimension
  Dim ->
  -- | list of tensors to concatenate
  [Tensor] ->
  -- | output tensor
  Tensor
cat (Dim d) tensors =
  -- ATen expects the tensor list first and the unwrapped dimension second.
  unsafePerformIO $ cast2 ATen.cat_ll tensors d

-- | Forwards the input tensor and a list of index tensors to ATen's @index@
-- (presumably advanced indexing, i.e. @input[indices]@ in PyTorch; confirm).
index ::
  -- | indices
  [Tensor] ->
  -- | input
  Tensor ->
  -- | output
  Tensor
index indices input =
  -- ATen expects the tensor first and the index list second.
  unsafePerformIO $ cast2 ATen.index_tl input indices

-- | Copies the elements of tensor into the self tensor (out-of-place) by selecting the indices in the order given in index.
-- For example, if dim == 0 and index[i] == j, then the ith row of tensor is copied to the jth row of self.
-- The dimth dimension of tensor must have the same size as the length of index (which must be a vector), and all other dimensions must match self, or an error will be raised.
indexCopy ::
  -- | dim
  Int ->
  -- | index
  Tensor ->
  -- | source
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
indexCopy dim idx source t =
  -- The local is named @idx@ so it does not shadow the top-level 'index'.
  -- ATen takes the input tensor first, then dim, index and source.
  unsafePerformIO $
    cast4 ATen.index_copy_tltt t dim idx source

-- | Same argument layout as 'indexCopy', but the dimension is given as a
-- 'Dimname' and the call goes to the dimname overload of ATen's @index_copy@.
indexCopyWithDimname ::
  -- | dim
  Dimname ->
  -- | index
  Tensor ->
  -- | source
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
indexCopyWithDimname dim idx source t =
  -- The local is named @idx@ so it does not shadow the top-level 'index'.
  -- ATen takes the input tensor first, then dim, index and source.
  unsafePerformIO $
    cast4 ATen.index_copy_tntt t dim idx source

-- | Puts values from the tensor value into the input tensor (out-of-place)
-- using the indices specified in indices (which is a tuple of Tensors).
-- The expression tensor.index_put_(indices, value) is equivalent to tensor[indices] = value.
-- If accumulate is True, the elements in value are added to self. If accumulate is False, the behavior is undefined if indices contain duplicate elements.
indexPut ::
  -- | accumulate
  Bool ->
  -- | indices
  [Tensor] ->
  -- | values
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
indexPut accumulate indices values self =
  -- ATen takes the input tensor first and the accumulate flag last.
  unsafePerformIO $
    cast4 ATen.index_put_tltb self indices values accumulate

-- | Splits a tensor into a specific number of chunks.
-- If the tensor size along the given dimension is not divisible by the
-- number of chunks, the last chunk will be smaller.
chunk ::
  -- | chunks
  Int ->
  -- | dim
  Dim ->
  -- | input tensor
  Tensor ->
  -- | output list of tensors
  [Tensor]
chunk chunks (Dim dimIx) input = unsafePerformIO pieces
  where
    pieces :: IO [Tensor]
    pieces = cast3 ATen.chunk_tll input chunks dimIx

-- | Clamps every element of the input into the range [ min, max ] and
-- returns the resulting tensor.
clamp ::
  -- | minimum value
  Float ->
  -- | maximum value
  Float ->
  -- | input
  Tensor ->
  -- | output
  Tensor
clamp lower upper input =
  unsafePerformIO (cast3 ATen.clamp_tss input lower upper)

-- | Clamps all elements of the input to be smaller than or equal to the
-- given maximum.
clampMax ::
  -- | maximum value
  Float ->
  -- | input
  Tensor ->
  -- | output
  Tensor
clampMax upper input =
  unsafePerformIO (cast2 ATen.clamp_max_ts input upper)

-- | Clamps all elements of the input to be larger than or equal to the
-- given minimum.
clampMin ::
  -- | minimum value
  Float ->
  -- | input
  Tensor ->
  -- | output
  Tensor
clampMin lower input =
  unsafePerformIO (cast2 ATen.clamp_min_ts input lower)

-- | Pure wrapper around 'ATen.cudnn_is_acceptable_t': forwards the tensor to
-- the ATen call and returns its boolean answer.
cudnnIsAcceptable ::
  -- | input
  Tensor ->
  -- | output
  Bool
cudnnIsAcceptable input = unsafePerformIO verdict
  where
    verdict :: IO Bool
    verdict = cast1 ATen.cudnn_is_acceptable_t input

-- | Pads the boundaries of the input tensor with a constant value.
constantPadNd1d ::
  -- | list of padding per dimension
  [Int] ->
  -- | value
  Float ->
  -- | input
  Tensor ->
  -- | output
  Tensor
constantPadNd1d padding value input = unsafePerformIO padded
  where
    -- ATen order is (input, padding, value).
    padded :: IO Tensor
    padded = cast3 ATen.constant_pad_nd_tls input padding value

--
-- convolutions
--

-- | Applies a 1D convolution over an input signal composed of several input
-- planes. Delegates to 'ATen.conv1d_tttllll'.
conv1d ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | stride
  Int ->
  -- | padding
  Int ->
  -- | dilation
  Int ->
  -- | groups
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
conv1d weight bias stride padding dilation groups input =
  unsafePerformIO convolved
  where
    -- The ATen binding takes the input first, then weight and bias.
    convolved :: IO Tensor
    convolved =
      cast7
        ATen.conv1d_tttllll
        input
        weight
        bias
        stride
        padding
        dilation
        groups

-- | 'conv1d' with dilation and groups both fixed to 1.
conv1d' ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  Int ->
  -- | padding
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
conv1d' weight bias stride padding input =
  conv1d weight bias stride padding defaultDilation defaultGroups input
  where
    defaultDilation, defaultGroups :: Int
    defaultDilation = 1
    defaultGroups = 1

-- | Applies a 2D convolution over an input signal composed of several input
-- planes. Delegates to 'ATen.conv2d_tttllll'; each pair is passed to ATen as
-- a two-element list.
conv2d ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  (Int, Int) ->
  -- | padding
  (Int, Int) ->
  -- | dilation
  (Int, Int) ->
  -- | groups
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
conv2d weight bias (s0, s1) (p0, p1) (d0, d1) groups input =
  unsafePerformIO convolved
  where
    strides, paddings, dilations :: [Int]
    strides = [s0, s1]
    paddings = [p0, p1]
    dilations = [d0, d1]

    convolved :: IO Tensor
    convolved =
      cast7
        ATen.conv2d_tttllll
        input
        weight
        bias
        strides
        paddings
        dilations
        groups

-- | 'conv2d' with dilation fixed to @(1, 1)@ and groups fixed to 1.
conv2d' ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  (Int, Int) ->
  -- | padding
  (Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
conv2d' weight bias stride padding input =
  conv2d weight bias stride padding defaultDilation defaultGroups input
  where
    defaultDilation :: (Int, Int)
    defaultDilation = (1, 1)

    defaultGroups :: Int
    defaultGroups = 1

-- | Applies a 3D convolution over an input signal composed of several input
-- planes. Delegates to 'ATen.conv3d_tttllll'; each triple is passed to ATen
-- as a three-element list.
conv3d ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  (Int, Int, Int) ->
  -- | padding
  (Int, Int, Int) ->
  -- | dilation
  (Int, Int, Int) ->
  -- | groups
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
conv3d weight bias (s0, s1, s2) (p0, p1, p2) (d0, d1, d2) groups input =
  unsafePerformIO convolved
  where
    strides, paddings, dilations :: [Int]
    strides = [s0, s1, s2]
    paddings = [p0, p1, p2]
    dilations = [d0, d1, d2]

    convolved :: IO Tensor
    convolved =
      cast7
        ATen.conv3d_tttllll
        input
        weight
        bias
        strides
        paddings
        dilations
        groups

-- | 'conv3d' with dilation fixed to @(1, 1, 1)@ and groups fixed to 1.
conv3d' ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  (Int, Int, Int) ->
  -- | padding
  (Int, Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
conv3d' weight bias stride padding input =
  conv3d weight bias stride padding defaultDilation defaultGroups input
  where
    defaultDilation :: (Int, Int, Int)
    defaultDilation = (1, 1, 1)

    defaultGroups :: Int
    defaultGroups = 1

-- | Applies a 1D transposed convolution over an input signal composed of
-- several input planes. Delegates to 'ATen.conv_transpose1d_tttllll'.
convTranspose1d ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  Int ->
  -- | padding
  Int ->
  -- | output padding
  Int ->
  -- | groups
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
convTranspose1d weight bias stride padding outPadding groups input =
  unsafePerformIO convolved
  where
    -- The ATen binding takes the input first, then weight and bias.
    convolved :: IO Tensor
    convolved =
      cast7
        ATen.conv_transpose1d_tttllll
        input
        weight
        bias
        stride
        padding
        outPadding
        groups

-- | 'convTranspose1d' with output padding fixed to 0 and groups fixed to 1.
convTranspose1d' ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  Int ->
  -- | padding
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
convTranspose1d' weight bias stride padding input =
  convTranspose1d weight bias stride padding defaultOutPadding defaultGroups input
  where
    defaultOutPadding, defaultGroups :: Int
    defaultOutPadding = 0
    defaultGroups = 1

-- | Applies a 2D transposed convolution over an input signal composed of
-- several input planes. Delegates to 'ATen.conv_transpose2d_tttllll'; each
-- pair is passed to ATen as a two-element list.
convTranspose2d ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  (Int, Int) ->
  -- | padding
  (Int, Int) ->
  -- | output padding
  (Int, Int) ->
  -- | groups
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
convTranspose2d weight bias (s0, s1) (p0, p1) (o0, o1) groups input =
  unsafePerformIO convolved
  where
    strides, paddings, outPaddings :: [Int]
    strides = [s0, s1]
    paddings = [p0, p1]
    outPaddings = [o0, o1]

    convolved :: IO Tensor
    convolved =
      cast7
        ATen.conv_transpose2d_tttllll
        input
        weight
        bias
        strides
        paddings
        outPaddings
        groups

-- | 'convTranspose2d' with output padding fixed to @(0, 0)@ and groups fixed
-- to 1.
convTranspose2d' ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  (Int, Int) ->
  -- | padding
  (Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
convTranspose2d' weight bias stride padding input =
  convTranspose2d weight bias stride padding defaultOutPadding defaultGroups input
  where
    defaultOutPadding :: (Int, Int)
    defaultOutPadding = (0, 0)

    defaultGroups :: Int
    defaultGroups = 1

-- | Applies a 3D transposed convolution over an input signal composed of
-- several input planes. Delegates to 'ATen.conv_transpose3d_tttllll'; each
-- triple is passed to ATen as a three-element list.
convTranspose3d ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  (Int, Int, Int) ->
  -- | padding
  (Int, Int, Int) ->
  -- | output padding
  (Int, Int, Int) ->
  -- | groups
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
convTranspose3d weight bias (s0, s1, s2) (p0, p1, p2) (o0, o1, o2) groups input =
  unsafePerformIO convolved
  where
    strides, paddings, outPaddings :: [Int]
    strides = [s0, s1, s2]
    paddings = [p0, p1, p2]
    outPaddings = [o0, o1, o2]

    convolved :: IO Tensor
    convolved =
      cast7
        ATen.conv_transpose3d_tttllll
        input
        weight
        bias
        strides
        paddings
        outPaddings
        groups

-- | 'convTranspose3d' with output padding fixed to @(0, 0, 0)@ and groups
-- fixed to 1.
convTranspose3d' ::
  -- | weight
  Tensor ->
  -- | bias
  Tensor ->
  -- | strides
  (Int, Int, Int) ->
  -- | padding
  (Int, Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
convTranspose3d' weight bias stride padding input =
  convTranspose3d weight bias stride padding defaultOutPadding defaultGroups input
  where
    defaultOutPadding :: (Int, Int, Int)
    defaultOutPadding = (0, 0, 0)

    defaultGroups :: Int
    defaultGroups = 1

-- | Returns a new tensor holding the signs of the elements of @input@.
sign ::
  -- | input
  Tensor ->
  -- | output
  Tensor
sign input = unsafePerformIO (cast1 ATen.sign_t input)

-- | Returns a tensor that is a transposed version of @input@: the two given
-- dimensions are swapped.
transpose ::
  -- | first dimension to swap
  Dim ->
  -- | second dimension to swap
  Dim ->
  -- | input
  Tensor ->
  -- | output
  Tensor
transpose (Dim dimA) (Dim dimB) input =
  unsafePerformIO (cast3 ATen.transpose_tll input dimA dimB)

-- | Special case of 'transpose' for a 2D tensor: swaps dimensions 0 and 1.
transpose2D ::
  -- | input
  Tensor ->
  -- | output
  Tensor
transpose2D input = transpose (Dim 0) (Dim 1) input

-- | Returns a tensor with the elements of input as the diagonal.
-- The 'Diag' argument controls which diagonal to consider:
--        If it is 0, it is the main diagonal.
--        If it is > 0, it is above the main diagonal.
--        If it is < 0, it is below the main diagonal.
diag ::
  -- | diagonal
  Diag ->
  -- | input
  Tensor ->
  -- | output
  Tensor
diag (Diag offset) input =
  unsafePerformIO (cast2 ATen.tensor_diag_l input offset)

-- | Pure wrapper around 'ATen.diag_embed_tlll'; the offset and the two
-- dimensions are unwrapped and forwarded unchanged.
diagEmbed ::
  -- | offset
  Diag ->
  -- | dim1
  Dim ->
  -- | dim2
  Dim ->
  -- | self
  Tensor ->
  -- | output
  Tensor
diagEmbed (Diag offset) (Dim dimA) (Dim dimB) input = unsafePerformIO embedded
  where
    embedded :: IO Tensor
    embedded = cast4 ATen.diag_embed_tlll input offset dimA dimB

-- | If input is a vector (1-D tensor), returns a 2-D square tensor with the
-- elements of input as the diagonal.
-- If input has more than one dimension, returns a 2-D tensor whose diagonal
-- elements are the flattened input.
-- The offset controls which diagonal to consider:
--  If offset = 0, it is the main diagonal.
--  If offset > 0, it is above the main diagonal.
--  If offset < 0, it is below the main diagonal.
diagflat ::
  -- | offset
  Diag ->
  -- | self
  Tensor ->
  -- | output
  Tensor
diagflat (Diag offset) input =
  unsafePerformIO (cast2 ATen.diagflat_tl input offset)

-- | Returns a partial view of input with its diagonal elements with respect
-- to dim1 and dim2 appended as a dimension at the end of the shape.
-- Applying 'diagEmbed' to the output of this function with the same
-- arguments yields a diagonal matrix with the diagonal entries of the input.
-- However, 'diagEmbed' has different default dimensions, so those need to be
-- explicitly specified.
diagonal ::
  -- | offset
  Diag ->
  -- | dim1
  Dim ->
  -- | dim2
  Dim ->
  -- | input
  Tensor ->
  -- | output
  Tensor
diagonal (Diag offset) (Dim dimA) (Dim dimB) input = unsafePerformIO viewed
  where
    viewed :: IO Tensor
    viewed = cast4 ATen.diagonal_tlll input offset dimA dimB

-- | Returns True if all elements in the tensor are True, False otherwise.
all ::
  -- | input
  Tensor ->
  -- | output
  Bool
all input = toInt reduced == 1
  where
    -- Result tensor of 'ATen.all_t'; it is read back as an 'Int' and
    -- compared with 1 to obtain the 'Bool'.
    reduced :: Tensor
    reduced = unsafePerformIO (cast1 ATen.all_t input)

-- | Returns True if any elements in the tensor are True, False otherwise.
any ::
  -- | input
  Tensor ->
  -- | output
  Bool
any input = toInt reduced == 1
  where
    -- Result tensor of 'ATen.any_t'; it is read back as an 'Int' and
    -- compared with 1 to obtain the 'Bool'.
    reduced :: Tensor
    reduced = unsafePerformIO (cast1 ATen.any_t input)

-- | Returns True if all elements in each row of the tensor in the given
-- dimension dim are True, False otherwise.
-- If keepdim is True, the output tensor has the same size as input except in
-- the dimension dim, where it has size 1. Otherwise dim is squeezed, so the
-- output tensor has 1 fewer dimension than input.
allDim ::
  -- | dimension
  Dim ->
  -- | boolean corresponding to keepdim
  Bool ->
  -- | input
  Tensor ->
  -- | output
  Tensor
allDim (Dim dimIx) keepdim input =
  unsafePerformIO (cast3 ATen.all_tlb input dimIx keepdim)

-- | Returns True if any elements in each row of the tensor in the given
-- dimension dim are True, False otherwise.
-- If keepdim is True, the output tensor has the same size as input except in
-- the dimension dim, where it has size 1. Otherwise dim is squeezed, so the
-- output tensor has 1 fewer dimension than input.
anyDim ::
  -- | dimension
  Dim ->
  -- | boolean corresponding to keepdim
  Bool ->
  -- | input
  Tensor ->
  -- | output
  Tensor
anyDim (Dim dimIx) keepdim input =
  unsafePerformIO (cast3 ATen.any_tlb input dimIx keepdim)

-- | Permutes the dimensions of this tensor.
permute ::
  -- | list corresponding to ordering of dimensions to permute with
  [Int] ->
  -- | input
  Tensor ->
  -- | output
  Tensor
permute dims input =
  unsafePerformIO (cast2 ATen.tensor_permute_l input dims)

-- | expand: pure wrapper around 'ATen.tensor_expand_lb'.
-- TODO: figure out what the `implicit` boolean value does
expand ::
  -- | input
  Tensor ->
  -- | some boolean value with unknown function
  Bool ->
  -- | the desired expanded size
  [Int] ->
  -- | output
  Tensor
expand input someBool dims = unsafePerformIO expanded
  where
    -- ATen takes the size list before the boolean, the reverse of this
    -- wrapper's parameter order.
    expanded :: IO Tensor
    expanded = cast3 ATen.tensor_expand_lb input dims someBool

-- | flatten: pure wrapper around 'ATen.flatten_tll', forwarding the start
-- and end dimensions unchanged.
flatten ::
  -- | startDim
  Dim ->
  -- | endDim
  Dim ->
  -- | self
  Tensor ->
  -- | output
  Tensor
flatten (Dim startDim) (Dim endDim) input =
  unsafePerformIO (cast3 ATen.flatten_tll input startDim endDim)

-- | flattenAll: calls 'ATen.flatten_tll' with start dimension 0 and end
-- dimension -1.
flattenAll ::
  -- | input
  Tensor ->
  -- | output
  Tensor
flattenAll input =
  unsafePerformIO (cast3 ATen.flatten_tll input startDim endDim)
  where
    startDim, endDim :: Int
    startDim = 0
    endDim = -1

-- | Pure wrapper around 'ATen.lstm_tllbldbbb'. All arguments are forwarded
-- to the ATen call in the order in which they are given here.
lstm ::
  -- | input
  Tensor ->
  -- | hx
  [Tensor] ->
  -- | params
  [Tensor] ->
  -- | has_biases
  Bool ->
  -- | num_layers
  Int ->
  -- | dropout
  Double ->
  -- | train
  Bool ->
  -- | bidirectional
  Bool ->
  -- | batch_first
  Bool ->
  -- | triple of tensors returned by the ATen call
  (Tensor, Tensor, Tensor)
lstm input hx params hasBiases numLayers dropout train bidirectional batchFirst =
  unsafePerformIO run
  where
    run :: IO (Tensor, Tensor, Tensor)
    run =
      cast9
        ATen.lstm_tllbldbbb
        input
        hx
        params
        hasBiases
        numLayers
        dropout
        train
        bidirectional
        batchFirst

-- | Pure wrapper around 'ATen.lstm_ttllbldbb', the variant that takes a data
-- tensor together with a batch-sizes tensor. The data tensor is the last
-- argument here but the first argument of the ATen call.
lstm' ::
  -- | batch_sizes
  Tensor ->
  -- | hx
  [Tensor] ->
  -- | params
  [Tensor] ->
  -- | has_biases
  Bool ->
  -- | num_layers
  Int ->
  -- | dropout
  Double ->
  -- | train
  Bool ->
  -- | bidirectional
  Bool ->
  -- | data
  Tensor ->
  -- | triple of tensors returned by the ATen call
  (Tensor, Tensor, Tensor)
lstm' batchSizes hx params hasBiases numLayers dropout train bidirectional packed =
  unsafePerformIO run
  where
    run :: IO (Tensor, Tensor, Tensor)
    run =
      cast9
        ATen.lstm_ttllbldbb
        packed
        batchSizes
        hx
        params
        hasBiases
        numLayers
        dropout
        train
        bidirectional

-- | Pure wrapper around 'ATen.gru_ttlbldbbb'. The input tensor is the last
-- argument here but the first argument of the ATen call.
gru ::
  -- | hx
  Tensor ->
  -- | params
  [Tensor] ->
  -- | has_biases
  Bool ->
  -- | num_layers
  Int ->
  -- | dropout
  Double ->
  -- | train
  Bool ->
  -- | bidirectional
  Bool ->
  -- | batch_first
  Bool ->
  -- | input
  Tensor ->
  -- | pair of tensors returned by the ATen call
  (Tensor, Tensor)
gru hx params hasBiases numLayers dropout train bidirectional batchFirst input =
  unsafePerformIO run
  where
    run :: IO (Tensor, Tensor)
    run =
      cast9
        ATen.gru_ttlbldbbb
        input
        hx
        params
        hasBiases
        numLayers
        dropout
        train
        bidirectional
        batchFirst

-- | Pure wrapper around 'ATen.gru_tttlbldbb', the variant that takes a data
-- tensor together with a batch-sizes tensor. The data tensor is the last
-- argument here but the first argument of the ATen call.
gru' ::
  -- | batch_sizes
  Tensor ->
  -- | hx
  Tensor ->
  -- | params
  [Tensor] ->
  -- | has_biases
  Bool ->
  -- | num_layers
  Int ->
  -- | dropout
  Double ->
  -- | train
  Bool ->
  -- | bidirectional
  Bool ->
  -- | data
  Tensor ->
  -- | pair of tensors returned by the ATen call
  (Tensor, Tensor)
gru' batchSizes hx params hasBiases numLayers dropout train bidirectional packed =
  unsafePerformIO run
  where
    run :: IO (Tensor, Tensor)
    run =
      cast9
        ATen.gru_tttlbldbb
        packed
        batchSizes
        hx
        params
        hasBiases
        numLayers
        dropout
        train
        bidirectional

-- | Pure wrapper around 'ATen.rnn_tanh_ttlbldbbb'. The input tensor is the
-- last argument here but the first argument of the ATen call.
rnnTanh ::
  -- | hx
  Tensor ->
  -- | params
  [Tensor] ->
  -- | has_biases
  Bool ->
  -- | num_layers
  Int ->
  -- | dropout
  Double ->
  -- | train
  Bool ->
  -- | bidirectional
  Bool ->
  -- | batch_first
  Bool ->
  -- | input
  Tensor ->
  -- | pair of tensors returned by the ATen call
  (Tensor, Tensor)
rnnTanh hx params hasBiases numLayers dropout train bidirectional batchFirst input =
  unsafePerformIO run
  where
    run :: IO (Tensor, Tensor)
    run =
      cast9
        ATen.rnn_tanh_ttlbldbbb
        input
        hx
        params
        hasBiases
        numLayers
        dropout
        train
        bidirectional
        batchFirst

-- | Pure wrapper around 'ATen.rnn_tanh_tttlbldbb', the variant that takes a
-- data tensor together with a batch-sizes tensor. The data tensor is the
-- last argument here but the first argument of the ATen call.
rnnTanh' ::
  -- | batch_sizes
  Tensor ->
  -- | hx
  Tensor ->
  -- | params
  [Tensor] ->
  -- | has_biases
  Bool ->
  -- | num_layers
  Int ->
  -- | dropout
  Double ->
  -- | train
  Bool ->
  -- | bidirectional
  Bool ->
  -- | data
  Tensor ->
  -- | pair of tensors returned by the ATen call
  (Tensor, Tensor)
rnnTanh' batchSizes hx params hasBiases numLayers dropout train bidirectional packed =
  unsafePerformIO run
  where
    run :: IO (Tensor, Tensor)
    run =
      cast9
        ATen.rnn_tanh_tttlbldbb
        packed
        batchSizes
        hx
        params
        hasBiases
        numLayers
        dropout
        train
        bidirectional

-- | Pure wrapper around the ATen binding @rnn_relu_ttlbldbbb@.
--
-- The Haskell API takes the input tensor last, while the ATen binding expects
-- it first; this wrapper only reorders the arguments.
-- NOTE(review): presumably returns (output, final hidden state) — confirm
-- against the ATen @rnn_relu@ documentation.
rnnRelu ::
  -- | hx
  Tensor ->
  -- | params
  [Tensor] ->
  -- | has_biases
  Bool ->
  -- | num_layers
  Int ->
  -- | dropout
  Double ->
  -- | train
  Bool ->
  -- | bidirectional
  Bool ->
  -- | batch_first
  Bool ->
  -- | input
  Tensor ->
  (Tensor, Tensor)
rnnRelu _hx _params _has_biases _num_layers _dropout _train _bidirectional _batch_first _input =
  unsafePerformIO $
    cast9
      ATen.rnn_relu_ttlbldbbb
      _input
      _hx
      _params
      _has_biases
      _num_layers
      _dropout
      _train
      _bidirectional
      _batch_first

-- | Pure wrapper around the ATen binding @rnn_relu_tttlbldbb@, the variant
-- that takes a data tensor together with a batch_sizes tensor.
--
-- Unlike 'rnnTanh'', this function takes @data@ as its /first/ argument, so
-- the arguments are forwarded to ATen in the order they are received. The
-- order is kept as-is for backward compatibility with existing callers.
-- NOTE(review): presumably this is the packed-sequence form and returns
-- (output, final hidden state) — confirm against the ATen documentation.
rnnRelu' ::
  -- | data
  Tensor ->
  -- | batch_sizes
  Tensor ->
  -- | hx
  Tensor ->
  -- | params
  [Tensor] ->
  -- | has_biases
  Bool ->
  -- | num_layers
  Int ->
  -- | dropout
  Double ->
  -- | train
  Bool ->
  -- | bidirectional
  Bool ->
  (Tensor, Tensor)
rnnRelu' _data _batch_sizes _hx _params _has_biases _num_layers _dropout _train _bidirectional =
  unsafePerformIO $
    cast9
      ATen.rnn_relu_tttlbldbb
      _data
      _batch_sizes
      _hx
      _params
      _has_biases
      _num_layers
      _dropout
      _train
      _bidirectional

-- | A long short-term memory (LSTM) cell.
--
-- Pure wrapper around the ATen binding @lstm_cell_tltttt@. The state pair is
-- forwarded to ATen as a two-element tensor list @[hx, cx]@.
lstmCell ::
  -- | input-hidden weights (4*hidden_size, input_size)
  Tensor ->
  -- | hidden-hidden weights (4*hidden_size, hidden_size)
  Tensor ->
  -- | input-hidden bias (4*hidden_size)
  Tensor ->
  -- | hidden-hidden bias, of shape (4*hidden_size)
  Tensor ->
  -- | (hidden state, cell state)
  (Tensor, Tensor) ->
  -- | input
  Tensor ->
  -- | next hidden state, next cell state
  (Tensor, Tensor)
lstmCell _w_ih _w_hh _b_ih _b_hh (_hx, _cx) _input =
  unsafePerformIO $
    cast6
      ATen.lstm_cell_tltttt
      _input
      ([_hx, _cx] :: [Tensor]) -- TODO: make cast work with 2-tuples
      _w_ih
      _w_hh
      _b_ih
      _b_hh

-- | A gated recurrent unit (GRU) cell.
--
-- Pure wrapper around the ATen binding @gru_cell_tttttt@; the input and
-- hidden state are taken last here but forwarded to ATen first.
gruCell ::
  -- | input-hidden weights
  Tensor ->
  -- | hidden-hidden weights
  Tensor ->
  -- | input-hidden bias
  Tensor ->
  -- | hidden-hidden bias
  Tensor ->
  -- | hidden state
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
gruCell _w_ih _w_hh _b_ih _b_hh _hx _input =
  unsafePerformIO $
    cast6
      ATen.gru_cell_tttttt
      _input
      _hx
      _w_ih
      _w_hh
      _b_ih
      _b_hh

-- | An Elman RNN cell with tanh non-linearity.
--
-- Pure wrapper around the ATen binding @rnn_tanh_cell_tttttt@; the input and
-- hidden state are taken last here but forwarded to ATen first.
rnnTanhCell ::
  -- | input-hidden weights
  Tensor ->
  -- | hidden-hidden weights
  Tensor ->
  -- | input-hidden bias
  Tensor ->
  -- | hidden-hidden bias
  Tensor ->
  -- | hidden state
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
rnnTanhCell _w_ih _w_hh _b_ih _b_hh _hx _input =
  unsafePerformIO $
    cast6
      ATen.rnn_tanh_cell_tttttt
      _input
      _hx
      _w_ih
      _w_hh
      _b_ih
      _b_hh

-- | An Elman RNN cell with ReLU non-linearity.
--
-- Pure wrapper around the ATen binding @rnn_relu_cell_tttttt@; the input and
-- hidden state are taken last here but forwarded to ATen first.
rnnReluCell ::
  -- | input-hidden weights
  Tensor ->
  -- | hidden-hidden weights
  Tensor ->
  -- | input-hidden bias
  Tensor ->
  -- | hidden-hidden bias
  Tensor ->
  -- | hidden state
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
rnnReluCell _w_ih _w_hh _b_ih _b_hh _hx _input =
  unsafePerformIO $
    cast6
      ATen.rnn_relu_cell_tttttt
      _input
      _hx
      _w_ih
      _w_hh
      _b_ih
      _b_hh

-- | A quantized long short-term memory (LSTM) cell.
--
-- Pure wrapper around the ATen binding
-- @quantized_lstm_cell_tlttttttttssss@. The state pair is forwarded to ATen
-- as a two-element tensor list @[hx, cx]@, and the four 'Float' arguments are
-- cast to ATen scalars.
quantizedLstmCell ::
  -- | input-hidden weights
  Tensor ->
  -- | hidden-hidden weights
  Tensor ->
  -- | input-hidden bias
  Tensor ->
  -- | hidden-hidden bias
  Tensor ->
  -- | input-hidden packed
  Tensor ->
  -- | hidden-hidden packed
  Tensor ->
  -- | input-hidden column offsets
  Tensor ->
  -- | hidden-hidden column offsets
  Tensor ->
  -- | input-hidden scale
  Float ->
  -- | hidden-hidden scale
  Float ->
  -- | input-hidden zero point
  Float ->
  -- | hidden-hidden zero point
  Float ->
  -- | (hidden state, cell state)
  (Tensor, Tensor) ->
  -- | input
  Tensor ->
  -- | output
  (Tensor, Tensor)
quantizedLstmCell
  _w_ih
  _w_hh
  _b_ih
  _b_hh
  _packed_ih
  _packed_hh
  _col_offsets_ih
  _col_offsets_hh
  _scale_ih
  _scale_hh
  _zero_point_ih
  _zero_point_hh
  (_hx, _cx)
  _input =
    unsafePerformIO $
      cast14
        ATen.quantized_lstm_cell_tlttttttttssss
        _input
        ([_hx, _cx] :: [Tensor])
        _w_ih
        _w_hh
        _b_ih
        _b_hh
        _packed_ih
        _packed_hh
        _col_offsets_ih
        _col_offsets_hh
        _scale_ih
        _scale_hh
        _zero_point_ih
        _zero_point_hh

-- | A quantized gated recurrent unit (GRU) cell.
--
-- Pure wrapper around the ATen binding
-- @quantized_gru_cell_ttttttttttssss@; the input and hidden state are taken
-- last here but forwarded to ATen first, and the four 'Float' arguments are
-- cast to ATen scalars.
quantizedGruCell ::
  -- | input-hidden weights
  Tensor ->
  -- | hidden-hidden weights
  Tensor ->
  -- | input-hidden bias
  Tensor ->
  -- | hidden-hidden bias
  Tensor ->
  -- | input-hidden packed
  Tensor ->
  -- | hidden-hidden packed
  Tensor ->
  -- | input-hidden column offsets
  Tensor ->
  -- | hidden-hidden column offsets
  Tensor ->
  -- | input-hidden scale
  Float ->
  -- | hidden-hidden scale
  Float ->
  -- | input-hidden zero point
  Float ->
  -- | hidden-hidden zero point
  Float ->
  -- | hidden state
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
quantizedGruCell
  _w_ih
  _w_hh
  _b_ih
  _b_hh
  _packed_ih
  _packed_hh
  _col_offsets_ih
  _col_offsets_hh
  _scale_ih
  _scale_hh
  _zero_point_ih
  _zero_point_hh
  _hx
  _input =
    unsafePerformIO $
      cast14
        ATen.quantized_gru_cell_ttttttttttssss
        _input
        _hx
        _w_ih
        _w_hh
        _b_ih
        _b_hh
        _packed_ih
        _packed_hh
        _col_offsets_ih
        _col_offsets_hh
        _scale_ih
        _scale_hh
        _zero_point_ih
        _zero_point_hh

-- | A quantized Elman RNN cell with relu non-linearity.
--
-- Pure wrapper around the ATen binding
-- @quantized_rnn_relu_cell_ttttttttttssss@; the input and hidden state are
-- taken last here but forwarded to ATen first, and the four 'Float' arguments
-- are cast to ATen scalars.
quantizedRnnReluCell ::
  -- | input-hidden weights
  Tensor ->
  -- | hidden-hidden weights
  Tensor ->
  -- | input-hidden bias
  Tensor ->
  -- | hidden-hidden bias
  Tensor ->
  -- | input-hidden packed
  Tensor ->
  -- | hidden-hidden packed
  Tensor ->
  -- | input-hidden column offsets
  Tensor ->
  -- | hidden-hidden column offsets
  Tensor ->
  -- | input-hidden scale
  Float ->
  -- | hidden-hidden scale
  Float ->
  -- | input-hidden zero point
  Float ->
  -- | hidden-hidden zero point
  Float ->
  -- | hidden state
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
quantizedRnnReluCell
  _w_ih
  _w_hh
  _b_ih
  _b_hh
  _packed_ih
  _packed_hh
  _col_offsets_ih
  _col_offsets_hh
  _scale_ih
  _scale_hh
  _zero_point_ih
  _zero_point_hh
  _hx
  _input =
    unsafePerformIO $
      cast14
        ATen.quantized_rnn_relu_cell_ttttttttttssss
        _input
        _hx
        _w_ih
        _w_hh
        _b_ih
        _b_hh
        _packed_ih
        _packed_hh
        _col_offsets_ih
        _col_offsets_hh
        _scale_ih
        _scale_hh
        _zero_point_ih
        _zero_point_hh

-- | A quantized Elman RNN cell with tanh non-linearity.
--
-- Pure wrapper around the ATen binding
-- @quantized_rnn_tanh_cell_ttttttttttssss@; the input and hidden state are
-- taken last here but forwarded to ATen first, and the four 'Float' arguments
-- are cast to ATen scalars.
quantizedRnnTanhCell ::
  -- | input-hidden weights
  Tensor ->
  -- | hidden-hidden weights
  Tensor ->
  -- | input-hidden bias
  Tensor ->
  -- | hidden-hidden bias
  Tensor ->
  -- | input-hidden packed
  Tensor ->
  -- | hidden-hidden packed
  Tensor ->
  -- | input-hidden column offsets
  Tensor ->
  -- | hidden-hidden column offsets
  Tensor ->
  -- | input-hidden scale
  Float ->
  -- | hidden-hidden scale
  Float ->
  -- | input-hidden zero point
  Float ->
  -- | hidden-hidden zero point
  Float ->
  -- | hidden state
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
quantizedRnnTanhCell
  _w_ih
  _w_hh
  _b_ih
  _b_hh
  _packed_ih
  _packed_hh
  _col_offsets_ih
  _col_offsets_hh
  _scale_ih
  _scale_hh
  _zero_point_ih
  _zero_point_hh
  _hx
  _input =
    unsafePerformIO $
      cast14
        ATen.quantized_rnn_tanh_cell_ttttttttttssss
        _input
        _hx
        _w_ih
        _w_hh
        _b_ih
        _b_hh
        _packed_ih
        _packed_hh
        _col_offsets_ih
        _col_offsets_hh
        _scale_ih
        _scale_hh
        _zero_point_ih
        _zero_point_hh

-- | Applies the soft shrinkage function elementwise.
--
-- Pure wrapper around the ATen binding @softshrink_ts@; @lambda@ is cast to
-- an ATen scalar.
softShrink ::
  -- | lambda
  Float ->
  -- | input
  Tensor ->
  -- | output
  Tensor
softShrink lambda input =
  unsafePerformIO $ cast2 ATen.softshrink_ts input lambda

-- | Concatenates a sequence of tensors along a new dimension.
-- All tensors need to be of the same size.
--
-- Pure wrapper around the ATen binding @stack_ll@.
stack ::
  -- | dim
  Dim ->
  -- | input
  [Tensor] ->
  -- | output
  Tensor
stack (Dim d) tensors =
  unsafePerformIO $ cast2 ATen.stack_ll tensors d

-- | Returns the sum of each row of the input tensor in the given dimension dim.
-- If keepdim is True, the output tensor is of the same size as input except in the dimension(s) dim where it is of size 1.
-- Otherwise, dim is squeezed, resulting in the output tensor having 1 (or len(dim)) fewer dimension(s).
--
-- Pure wrapper around the ATen binding @sum_tlbs@; the 'KeepDim' flag is
-- converted to a 'Bool' with @keepdim@ before the call.
sumDim ::
  -- | dim to sum along
  Dim ->
  -- | whether the output tensor has dim retained or not
  KeepDim ->
  -- | datatype
  DType ->
  -- | input
  Tensor ->
  -- | output
  Tensor
sumDim (Dim d) k dtype input =
  unsafePerformIO $ cast4 ATen.sum_tlbs input d (keepdim k) dtype

-- | Returns the k largest elements of the given input tensor along a given dimension.
-- If largest is False then the k smallest elements are returned.
-- The boolean option sorted, if True, will make sure that the returned k elements are themselves sorted.
-- A tuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.
--
-- Pure wrapper around the ATen binding @topk_tllbb@.
topK ::
  -- | k
  Int ->
  -- | dim to find topK along
  Dim ->
  -- | largest
  Bool ->
  -- | sorted
  Bool ->
  -- | input
  Tensor ->
  -- | output
  (Tensor, Tensor)
topK k (Dim d) largest sorted input =
  unsafePerformIO $ cast5 ATen.topk_tllbb input k d largest sorted

-- | Returns the log of summed exponentials of each row of the input tensor in the given dimension dim. The computation is numerically stabilized.
--
-- Pure wrapper around the ATen binding @logsumexp_tlb@; the single @dim@ is
-- cast to the integer array that the binding expects.
logsumexp ::
  -- | keepdim
  Bool ->
  -- | dim
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
logsumexp keepdim dim t =
  unsafePerformIO $ cast3 ATen.logsumexp_tlb t dim keepdim

-- | Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
-- The upper triangular part of the matrix is defined as the elements on and above the diagonal.
-- The argument diagonal controls which diagonal to consider. If diagonal = 0, all elements on and above the main diagonal are retained.
-- A positive value excludes just as many diagonals above the main diagonal, and similarly a negative value includes just as many diagonals below the main diagonal.
-- The main diagonal is the set of indices \((i,i)\) for \(i\) \(\in [0,\min(d_1,d_2)-1]\) where \(d_1\) and \(d_2 \) are the dimensions of the matrix.
--
-- Pure wrapper around the ATen binding @triu_tl@.
triu ::
  -- | diagonal
  Diag ->
  -- | input
  Tensor ->
  -- | output
  Tensor
triu (Diag diagonal) input =
  unsafePerformIO $ cast2 ATen.triu_tl input diagonal

-- | Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
-- The lower triangular part of the matrix is defined as the elements on and below the diagonal.
-- The argument diagonal controls which diagonal to consider. If diagonal = 0, all elements on and below the main diagonal are retained.
-- A positive value includes just as many diagonals above the main diagonal, and similarly a negative value excludes just as many diagonals below the main diagonal.
-- The main diagonal is the set of indices \((i,i)\) for \(i\) \(\in [0,\min(d_1,d_2)-1]\) where \(d_1\) and \(d_2 \) are the dimensions of the matrix.
--
-- Pure wrapper around the ATen binding @tril_tl@.
tril ::
  -- | diagonal
  Diag ->
  -- | input
  Tensor ->
  -- | output
  Tensor
tril (Diag diagonal) input =
  unsafePerformIO $ cast2 ATen.tril_tl input diagonal

-- | Returns a new tensor with the truncated integer values of the elements of input.
--
-- Pure wrapper around the ATen binding @trunc_t@.
trunc ::
  -- | input
  Tensor ->
  -- | output
  Tensor
trunc input = unsafePerformIO $ cast1 ATen.trunc_t input

-- | Returns the unique elements of the input tensor along a dimension.
--
-- Pure wrapper around the ATen binding @unique_dim_tlbbb@.
-- NOTE(review): the three result tensors are presumably
-- (values, inverse indices, counts) — confirm against the ATen documentation.
uniqueDim ::
  -- | dim
  Int ->
  -- | sorted
  Bool ->
  -- | return_inverse
  Bool ->
  -- | return_counts
  Bool ->
  -- | input
  Tensor ->
  -- | output
  (Tensor, Tensor, Tensor)
uniqueDim dim sorted returnInverse returnCounts self =
  unsafePerformIO $
    cast5 ATen.unique_dim_tlbbb self dim sorted returnInverse returnCounts

-- | Eliminates all but the first element from every consecutive group of equivalent elements.
-- This function is different from 'uniqueDim' in the sense that this function only eliminates consecutive duplicate values.
--
-- Pure wrapper around the ATen binding @unique_consecutive_tbbl@.
-- NOTE(review): the three result tensors are presumably
-- (values, inverse indices, counts) — confirm against the ATen documentation.
uniqueConsecutive ::
  -- | return_inverse
  Bool ->
  -- | return_counts
  Bool ->
  -- | dim
  Int ->
  -- | input
  Tensor ->
  -- | output
  (Tensor, Tensor, Tensor)
uniqueConsecutive returnInverse returnCounts dim self =
  unsafePerformIO $
    cast4 ATen.unique_consecutive_tbbl self returnInverse returnCounts dim

-- | Eliminates all but the first element from every consecutive group of equivalent elements along a dimension.
-- This function is different from 'uniqueDim' in the sense that this function only eliminates consecutive duplicate values.
--
-- Pure wrapper around the ATen binding @unique_dim_consecutive_tlbb@.
-- NOTE(review): the three result tensors are presumably
-- (values, inverse indices, counts) — confirm against the ATen documentation.
uniqueDimConsecutive ::
  -- | dim
  Int ->
  -- | return_inverse
  Bool ->
  -- | return_counts
  Bool ->
  -- | input
  Tensor ->
  -- | output
  (Tensor, Tensor, Tensor)
uniqueDimConsecutive dim returnInverse returnCounts self =
  unsafePerformIO $
    cast4 ATen.unique_dim_consecutive_tlbb self dim returnInverse returnCounts

-- | Returns a new tensor with a dimension of size one inserted at the specified position.
-- The returned tensor shares the same underlying data with this tensor.
-- A dim value within the range [-(dim input) - 1, (dim input) + 1) can be used. Negative dim will correspond to unsqueeze applied at dim = dim + (dim input) + 1
--
-- Pure wrapper around the ATen binding @unsqueeze_tl@.
unsqueeze ::
  -- | dim
  Dim ->
  -- | input
  Tensor ->
  -- | output
  Tensor
unsqueeze (Dim d) input =
  unsafePerformIO $ cast2 ATen.unsqueeze_tl input d

-- | Upsamples the input, using bilinear upsampling. Expec