{-# LANGUAGE FlexibleContexts #-}

module Torch.TensorFactories where

import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.Dimname
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Autograd as LibTorch
import Torch.Internal.Managed.Cast
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Scalar
import Torch.Tensor
import Torch.TensorOptions

-- XXX: We use the torch:: constructors, not the at:: constructors, because
--      otherwise we cannot use libtorch's AD.

-- | Signature of a raw libtorch factory: given the output size (as an
-- 'ATen.IntArray') and the tensor options, produce a tensor handle in 'IO'.
type FactoryType
  = ForeignPtr ATen.IntArray
  -> ForeignPtr ATen.TensorOptions
  -> IO (ForeignPtr ATen.Tensor)

-- | Like 'FactoryType', except that the raw factory also receives a list of
-- dimension names (as an 'ATen.DimnameList') alongside the sizes.
type FactoryTypeWithDimnames
  = ForeignPtr ATen.IntArray
  -> ForeignPtr ATen.DimnameList
  -> ForeignPtr ATen.TensorOptions
  -> IO (ForeignPtr ATen.Tensor)

-- | Lift a raw libtorch factory ('FactoryType') to the Haskell-side types.
-- The shape, the options and the resulting tensor are marshalled via 'cast2'.
mkFactory ::
  -- | aten_impl: the raw libtorch factory to invoke
  FactoryType ->
  -- | shape: size of each dimension of the output tensor
  [Int] ->
  -- | opts: data type, device, layout and other properties of the output
  TensorOptions ->
  -- | output
  IO Tensor
mkFactory aten_impl shape opts = cast2 aten_impl shape opts

-- | Pure variant of 'mkFactory': the factory action is run with
-- 'unsafePerformIO'.
--
-- NOTE(review): this is only sound for deterministic factories (in this
-- module it backs 'ones' and 'zeros'). The random factories deliberately
-- stay in 'IO'.
mkFactoryUnsafe :: FactoryType -> [Int] -> TensorOptions -> Tensor
mkFactoryUnsafe f shape = unsafePerformIO . mkFactory f shape

-- | Dimname-aware counterpart of 'mkFactory'. The @(size, name)@ pairs are
-- split into a size list and a name list, and both are marshalled via
-- 'cast3' together with the options.
mkFactoryWithDimnames :: FactoryTypeWithDimnames -> [(Int, Dimname)] -> TensorOptions -> IO Tensor
mkFactoryWithDimnames aten_impl shape opts = cast3 aten_impl sizes names opts
  where
    (sizes, names) = unzip shape

-- | Pure variant of 'mkFactoryWithDimnames': the factory action is run with
-- 'unsafePerformIO'.
--
-- NOTE(review): only sound for deterministic factories (in this module it
-- backs 'onesWithDimnames' and 'zerosWithDimnames').
mkFactoryUnsafeWithDimnames :: FactoryTypeWithDimnames -> [(Int, Dimname)] -> TensorOptions -> Tensor
mkFactoryUnsafeWithDimnames f shape opts =
  let action = mkFactoryWithDimnames f shape opts
   in unsafePerformIO action

-- | Specialise a shape-based factory to 'defaultOpts', so that only the
-- shape remains to be supplied.
mkDefaultFactory :: ([Int] -> TensorOptions -> a) -> [Int] -> a
mkDefaultFactory non_default = flip non_default defaultOpts

-- | Specialise a dimname-aware factory to 'defaultOpts', so that only the
-- named shape remains to be supplied.
mkDefaultFactoryWithDimnames :: ([(Int, Dimname)] -> TensorOptions -> a) -> [(Int, Dimname)] -> a
mkDefaultFactoryWithDimnames non_default = flip non_default defaultOpts

-------------------- Factories --------------------

-- | Returns a tensor of the requested shape in which every element is 1.
ones ::
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
ones shape opts = mkFactoryUnsafe LibTorch.ones_lo shape opts

-- TODO: ones_like from Native.hs is redundant with onesLike.

-- | Returns a tensor the same size as the input in which every element is 1
-- (native @ones_like@).
onesLike ::
  -- | input
  Tensor ->
  -- | output
  Tensor
onesLike = unsafePerformIO . cast1 ATen.ones_like_t

-- | Returns a tensor of the requested shape in which every element is 0.
zeros ::
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
zeros shape opts = mkFactoryUnsafe LibTorch.zeros_lo shape opts

-- | Returns a tensor the same size as the input in which every element is 0
-- (native @zeros_like@).
zerosLike ::
  -- | input
  Tensor ->
  -- | output
  Tensor
zerosLike = unsafePerformIO . cast1 ATen.zeros_like_t

-- | Returns a tensor of the requested shape whose elements are drawn
-- uniformly from the interval [0,1).
randIO ::
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  IO Tensor
randIO shape opts = mkFactory LibTorch.rand_lo shape opts

-- | Returns a tensor of the requested shape whose elements are drawn from a
-- standard normal distribution.
randnIO ::
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  IO Tensor
randnIO shape opts = mkFactory LibTorch.randn_lo shape opts

-- | Returns a tensor of the requested shape filled with integers drawn
-- uniformly from the half-open range [low, high).
randintIO ::
  -- | low: smallest integer that can be drawn (inclusive).
  Int ->
  -- | high: one above the largest integer that can be drawn (exclusive).
  Int ->
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  IO Tensor
randintIO low high = mkFactory (LibTorch.randint_lllo low64 high64)
  where
    -- the raw binding takes 64-bit bounds
    low64 = fromIntegral low
    high64 = fromIntegral high

-- | Returns a tensor the same size as the input whose elements are drawn
-- from a standard normal distribution (native @randn_like@).
randnLikeIO ::
  -- | input
  Tensor ->
  -- | output
  IO Tensor
randnLikeIO t = cast1 ATen.randn_like_t t

-- | Returns a tensor the same size as the input whose elements are drawn
-- uniformly from the interval [0,1).
randLikeIO ::
  -- | input
  Tensor ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  IO Tensor
randLikeIO t opts = cast2 LibTorch.rand_like_to t opts

-- | Wraps libtorch's @full_like@ ('LibTorch.full_like_tso').
--
-- NOTE(review): presumably this returns a tensor shaped like the input with
-- every element set to the fill value; confirm against libtorch docs.
fullLike ::
  -- | input
  Tensor ->
  -- | _fill_value
  Float ->
  -- | opt
  TensorOptions ->
  -- | output
  IO Tensor
fullLike t fillValue opts = cast3 LibTorch.full_like_tso t fillValue opts

-- | Like 'ones', but each dimension carries a name alongside its size.
onesWithDimnames ::
  -- | (size, name) pair per dimension
  [(Int, Dimname)] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
onesWithDimnames shape opts = mkFactoryUnsafeWithDimnames LibTorch.ones_lNo shape opts

-- | Like 'zeros', but each dimension carries a name alongside its size.
zerosWithDimnames ::
  -- | (size, name) pair per dimension
  [(Int, Dimname)] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
zerosWithDimnames shape opts = mkFactoryUnsafeWithDimnames LibTorch.zeros_lNo shape opts

-- | Like 'randIO', but each dimension carries a name alongside its size.
randWithDimnames ::
  -- | (size, name) pair per dimension
  [(Int, Dimname)] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  IO Tensor
randWithDimnames shape opts = mkFactoryWithDimnames LibTorch.rand_lNo shape opts

-- | Like 'randnIO', but each dimension carries a name alongside its size.
randnWithDimnames ::
  -- | (size, name) pair per dimension
  [(Int, Dimname)] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  IO Tensor
randnWithDimnames shape opts = mkFactoryWithDimnames LibTorch.randn_lNo shape opts

-- | Returns a one-dimensional tensor of @steps@ equally spaced points from
-- @start@ to @end@.
linspace ::
  (Scalar a, Scalar b) =>
  -- | @start@
  a ->
  -- | @end@
  b ->
  -- | @steps@
  Int ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
linspace start end steps =
  unsafePerformIO . cast4 LibTorch.linspace_sslo start end steps

-- | Wraps libtorch's @logspace@ ('LibTorch.logspace_ssldo').
--
-- NOTE(review): presumably this yields @steps@ points evenly spaced on a
-- log scale with the given @base@; confirm against libtorch docs.
logspace ::
  (Scalar a, Scalar b) =>
  -- | @start@
  a ->
  -- | @end@
  b ->
  -- | @steps@
  Int ->
  -- | @base@
  Double ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
logspace start end steps base =
  unsafePerformIO . cast5 LibTorch.logspace_ssldo start end steps base

-- https://github.com/hasktorch/ffi-experimental/pull/57#discussion_r301062033
-- empty :: [Int] -> TensorOptions -> Tensor
-- empty = mkFactoryUnsafe LibTorch.empty_lo

-- | Square identity-style matrix via the single-size @eye@ binding
-- ('LibTorch.eye_lo'). See 'eye' for the rectangular form.
eyeSquare ::
  -- | dim
  Int ->
  -- | opts
  TensorOptions ->
  -- | output
  Tensor
eyeSquare dim opts = unsafePerformIO $ cast2 LibTorch.eye_lo dim opts

-- | Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.
eye ::
  -- | number of rows
  Int ->
  -- | number of columns
  Int ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
eye nrows ncols = unsafePerformIO . cast3 LibTorch.eye_llo nrows ncols

-- | Returns a tensor of the given shape with every element set to the
-- supplied fill value.
full ::
  Scalar a =>
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | value written into every element
  a ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
full shape value = unsafePerformIO . cast3 LibTorch.full_lso shape value

-- | Constructs a sparse tensor in COO(rdinate) format, with non-zero
-- elements at the given indices holding the given values.
sparseCooTensor ::
  -- | coordinates of the non-zero values
  Tensor ->
  -- | values stored at those coordinates
  Tensor ->
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
sparseCooTensor indices values sizes opts =
  unsafePerformIO (cast4 build indices values sizes opts)
  where
    -- Both input handles pass through 'LibTorch.dropVariable' before the
    -- sparse constructor runs.
    -- NOTE(review): presumably this strips the autograd-variable wrapper;
    -- confirm against Torch.Internal.Managed.Autograd.
    build :: ForeignPtr ATen.Tensor -> ForeignPtr ATen.Tensor -> FactoryType
    build rawIndices rawValues rawSizes rawOpts = do
      plainIndices <- LibTorch.dropVariable rawIndices
      plainValues <- LibTorch.dropVariable rawValues
      LibTorch.sparse_coo_tensor_ttlo plainIndices plainValues rawSizes rawOpts

-------------------- Factories with default type --------------------

-- | 'ones' with 'defaultOpts'.
ones' ::
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | output
  Tensor
ones' shape = ones shape defaultOpts

-- | 'zeros' with 'defaultOpts'.
zeros' ::
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | output
  Tensor
zeros' shape = zeros shape defaultOpts

-- | 'randIO' with 'defaultOpts'.
randIO' ::
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | output
  IO Tensor
randIO' shape = randIO shape defaultOpts

-- | 'randnIO' with 'defaultOpts'.
randnIO' ::
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | output
  IO Tensor
randnIO' shape = randnIO shape defaultOpts

-- | 'randintIO' with 'defaultOpts'.
randintIO' ::
  -- | low (inclusive)
  Int ->
  -- | high (exclusive)
  Int ->
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | output
  IO Tensor
randintIO' low high shape = randintIO low high shape defaultOpts

-- | 'randLikeIO' with 'defaultOpts'.
randLikeIO' ::
  -- | input
  Tensor ->
  -- | output
  IO Tensor
randLikeIO' = flip randLikeIO defaultOpts

-- | Samples via the native @bernoulli@ op ('ATen.bernoulli_t').
--
-- NOTE(review): presumably each element of @t@ is read as a probability in
-- [0, 1]; confirm against libtorch docs.
bernoulliIO' ::
  -- | t
  Tensor ->
  -- | output
  IO Tensor
bernoulliIO' t = cast1 ATen.bernoulli_t t

-- | Samples via the native @bernoulli@ op with an explicit scalar @p@
-- ('ATen.bernoulli_td').
--
-- NOTE(review): presumably @p@ is the probability used for every element;
-- confirm against libtorch docs.
bernoulliIO ::
  -- | t
  Tensor ->
  -- | p
  Double ->
  -- | output
  IO Tensor
bernoulliIO t p = cast2 ATen.bernoulli_td t p

-- | Samples via the native @poisson@ op ('ATen.poisson_t').
--
-- NOTE(review): presumably @t@ holds the per-element rates; confirm against
-- libtorch docs.
poissonIO ::
  -- | t
  Tensor ->
  -- | output
  IO Tensor
poissonIO t = cast1 ATen.poisson_t t

-- | Samples via the native @multinomial@ op, passing only the sample count
-- ('ATen.multinomial_tl').
--
-- NOTE(review): no replacement flag is passed here, so whatever default the
-- underlying binding uses applies; confirm.
multinomialIO' ::
  -- | t
  Tensor ->
  -- | num_samples
  Int ->
  -- | output
  IO Tensor
multinomialIO' t numSamples = cast2 ATen.multinomial_tl t numSamples

-- | Samples via the native @multinomial@ op with an explicit replacement
-- flag ('ATen.multinomial_tlb').
multinomialIO ::
  -- | t
  Tensor ->
  -- | num_samples
  Int ->
  -- | replacement
  Bool ->
  -- | output
  IO Tensor
multinomialIO t numSamples replacement =
  cast3 ATen.multinomial_tlb t numSamples replacement

-- | Samples via the native @normal@ op given only a mean tensor
-- ('ATen.normal_t').
normalIO' ::
  -- | _mean
  Tensor ->
  -- | output
  IO Tensor
normalIO' mean = cast1 ATen.normal_t mean

-- | Samples via the native @normal@ op given a mean tensor and a standard
-- deviation tensor ('ATen.normal_tt').
normalIO ::
  -- | _mean
  Tensor ->
  -- | _std
  Tensor ->
  -- | output
  IO Tensor
normalIO mean std = cast2 ATen.normal_tt mean std

-- | Samples via the native @normal@ op given a mean tensor and a scalar
-- standard deviation ('ATen.normal_td').
normalScalarIO ::
  -- | _mean
  Tensor ->
  -- | _std
  Double ->
  -- | output
  IO Tensor
normalScalarIO mean std = cast2 ATen.normal_td mean std

-- | Samples via the native @normal@ op given a scalar mean and a standard
-- deviation tensor ('ATen.normal_dt').
normalScalarIO' ::
  -- | _mean
  Double ->
  -- | _std
  Tensor ->
  -- | output
  IO Tensor
normalScalarIO' mean std = cast2 ATen.normal_dt mean std

-- | Samples via the native @normal@ op given a scalar mean, a scalar
-- standard deviation and an output size ('ATen.normal_ddl').
--
-- NOTE(review): the single 'Int' is marshalled to an 'ATen.IntArray', so
-- the result is presumably one-dimensional; confirm.
normalWithSizeIO ::
  -- | _mean
  Double ->
  -- | _std
  Double ->
  -- | _size
  Int ->
  -- | output
  IO Tensor
normalWithSizeIO mean std sz = cast3 ATen.normal_ddl mean std sz

-- | Applies the native @rrelu@ op with its default bounds ('ATen.rrelu_t').
rreluIO''' ::
  -- | t
  Tensor ->
  -- | output
  IO Tensor
rreluIO''' t = cast1 ATen.rrelu_t t

-- | Applies the native @rrelu@ op with an explicit upper bound
-- ('ATen.rrelu_ts').
rreluIO'' ::
  Scalar a =>
  -- | t
  Tensor ->
  -- | upper
  a ->
  -- | output
  IO Tensor
rreluIO'' t upper = cast2 ATen.rrelu_ts t upper

-- | Applies the native @rrelu@ op with explicit lower and upper bounds
-- ('ATen.rrelu_tss').
rreluIO' ::
  Scalar a =>
  -- | t
  Tensor ->
  -- | lower
  a ->
  -- | upper
  a ->
  -- | output
  IO Tensor
rreluIO' t lower upper = cast3 ATen.rrelu_tss t lower upper

-- | Applies the native @rrelu@ op with explicit bounds and training flag
-- ('ATen.rrelu_tssb').
rreluIO ::
  Scalar a =>
  -- | t
  Tensor ->
  -- | lower
  a ->
  -- | upper
  a ->
  -- | training
  Bool ->
  -- | output
  IO Tensor
rreluIO t lower upper training = cast4 ATen.rrelu_tssb t lower upper training

-- | Applies the native @rrelu_with_noise@ op with its default bounds
-- ('ATen.rrelu_with_noise_tt').
rreluWithNoiseIO''' ::
  -- | t
  Tensor ->
  -- | noise
  Tensor ->
  -- | output
  IO Tensor
rreluWithNoiseIO''' t noise = cast2 ATen.rrelu_with_noise_tt t noise

-- | Applies the native @rrelu_with_noise@ op with an explicit upper bound
-- ('ATen.rrelu_with_noise_tts').
rreluWithNoiseIO'' ::
  Scalar a =>
  -- | t
  Tensor ->
  -- | noise
  Tensor ->
  -- | upper
  a ->
  -- | output
  IO Tensor
rreluWithNoiseIO'' t noise upper = cast3 ATen.rrelu_with_noise_tts t noise upper

-- | Applies the native @rrelu_with_noise@ op with explicit lower and upper
-- bounds ('ATen.rrelu_with_noise_ttss').
rreluWithNoiseIO' ::
  Scalar a =>
  -- | t
  Tensor ->
  -- | noise
  Tensor ->
  -- | lower
  a ->
  -- | upper
  a ->
  -- | output
  IO Tensor
rreluWithNoiseIO' t noise lower upper =
  cast4 ATen.rrelu_with_noise_ttss t noise lower upper

-- | Applies the native @rrelu_with_noise@ op with explicit bounds and
-- training flag ('ATen.rrelu_with_noise_ttssb').
rreluWithNoiseIO ::
  Scalar a =>
  -- | t
  Tensor ->
  -- | noise
  Tensor ->
  -- | lower
  a ->
  -- | upper
  a ->
  -- | training
  Bool ->
  -- | output
  IO Tensor
rreluWithNoiseIO t noise lower upper training =
  cast5 ATen.rrelu_with_noise_ttssb t noise lower upper training

-- | 'onesWithDimnames' with 'defaultOpts'.
onesWithDimnames' ::
  -- | (size, name) pair per dimension
  [(Int, Dimname)] ->
  -- | output
  Tensor
onesWithDimnames' shape = onesWithDimnames shape defaultOpts

-- | 'zerosWithDimnames' with 'defaultOpts'.
zerosWithDimnames' ::
  -- | (size, name) pair per dimension
  [(Int, Dimname)] ->
  -- | output
  Tensor
zerosWithDimnames' shape = zerosWithDimnames shape defaultOpts

-- | 'randWithDimnames' with 'defaultOpts'.
randWithDimnames' ::
  -- | (size, name) pair per dimension
  [(Int, Dimname)] ->
  -- | output
  IO Tensor
randWithDimnames' shape = randWithDimnames shape defaultOpts

-- | 'randnWithDimnames' with 'defaultOpts'.
randnWithDimnames' ::
  -- | (size, name) pair per dimension
  [(Int, Dimname)] ->
  -- | output
  IO Tensor
randnWithDimnames' shape = randnWithDimnames shape defaultOpts

-- | 'linspace' with 'defaultOpts'.
linspace' ::
  (Scalar a, Scalar b) =>
  -- | @start@
  a ->
  -- | @end@
  b ->
  -- | @steps@
  Int ->
  -- | output
  Tensor
linspace' start end = flip (linspace start end) defaultOpts

-- | 'logspace' with 'defaultOpts'.
logspace' ::
  (Scalar a, Scalar b) =>
  -- | @start@
  a ->
  -- | @end@
  b ->
  -- | @steps@
  Int ->
  -- | @base@
  Double ->
  -- | output
  Tensor
logspace' start end steps = flip (logspace start end steps) defaultOpts

-- | 'eyeSquare' with 'defaultOpts'.
eyeSquare' ::
  -- | dim
  Int ->
  -- | output
  Tensor
eyeSquare' = flip eyeSquare defaultOpts

-- | 'eye' with 'defaultOpts'.
eye' ::
  -- | number of rows
  Int ->
  -- | number of columns
  Int ->
  -- | output
  Tensor
eye' nrows = flip (eye nrows) defaultOpts

-- | 'full' with 'defaultOpts'.
full' ::
  Scalar a =>
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | value written into every element
  a ->
  -- | output
  Tensor
full' shape = flip (full shape) defaultOpts

-- | 'sparseCooTensor' with 'defaultOpts'.
sparseCooTensor' ::
  -- | coordinates of the non-zero values
  Tensor ->
  -- | values stored at those coordinates
  Tensor ->
  -- | size of each dimension of the output tensor.
  [Int] ->
  -- | output
  Tensor
sparseCooTensor' indices values = flip (sparseCooTensor indices values) defaultOpts

-- | Returns a 1-D tensor of values in the half-open interval [start, end),
-- beginning at @start@ and advancing by @step@.
arange ::
  -- | start
  Int ->
  -- | end
  Int ->
  -- | step
  Int ->
  -- | data type, device, layout and other properties of the result.
  TensorOptions ->
  -- | output
  Tensor
arange start end step = unsafePerformIO . cast4 ATen.arange_ssso start end step

-- | 'arange' with 'defaultOpts': a 1-D tensor of values in [start, end)
-- spaced @step@ apart.
arange' ::
  -- | start
  Int ->
  -- | end
  Int ->
  -- | step
  Int ->
  -- | output
  Tensor
arange' start end = flip (arange start end) defaultOpts