{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}

module Torch.Internal.Unmanaged.Type.Extra where


import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Cpp.Unsafe as C
import qualified Language.C.Inline.Context as C
import qualified Language.C.Types as C
import qualified Data.Map as Map
import Foreign.C.String
import Foreign.C.Types
import Foreign
import Torch.Internal.Type


-- inline-c context for every quasi-quote in this module: the stock C++
-- context ('C.cppCtx') extended with 'typeTable' (presumably the C++-name ->
-- Haskell-type mapping exported by "Torch.Internal.Type"), so blocks below
-- can mention C++ types such as @at::Tensor*@ in their antiquotes and
-- return types.
C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable }

-- Headers included in the generated C++ for this module's blocks:
-- ATen tensor/function/operator declarations, std::vector (used for the
-- name vector returned by 'tensor_names' and for shape vectors), and the
-- torch:: factory functions (presumably where torch::empty, used by
-- 'new_empty_tensor', is declared).
C.include "<ATen/Functions.h>"
C.include "<ATen/Tensor.h>"
C.include "<ATen/TensorOperators.h>"
C.include "<vector>"
C.include "<torch/csrc/autograd/generated/variable_factories.h>"

-- | Perform the C++ assignment @(*_obj)[_idx0] = _val@ with an integer value.
--
-- ATen's @Tensor::operator[]@ yields a tensor, and the assignment goes
-- through ATen's rvalue @Tensor::operator=@.
-- NOTE(review): this is expected to write into @_obj@'s storage in place;
-- confirm against the ATen version in use. C++ exceptions (e.g. an
-- out-of-range index) are rethrown in Haskell by 'C.throwBlock'.
tensor_assign1_l
  :: Ptr Tensor -- ^ tensor being written to
  -> Int64      -- ^ index passed to @operator[]@
  -> Int64      -- ^ value assigned
  -> IO ()
tensor_assign1_l _obj _idx0 _val =
  [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)] = $(int64_t _val); }|]

-- | Perform the C++ assignment @(*_obj)[_idx0][_idx1] = _val@ with an
-- integer value.
--
-- Each @operator[]@ yields a tensor, and the final assignment goes through
-- ATen's rvalue @Tensor::operator=@.
-- NOTE(review): this is expected to write into @_obj@'s storage in place;
-- confirm against the ATen version in use. C++ exceptions are rethrown in
-- Haskell by 'C.throwBlock'.
tensor_assign2_l
  :: Ptr Tensor -- ^ tensor being written to
  -> Int64      -- ^ first index
  -> Int64      -- ^ second index
  -> Int64      -- ^ value assigned
  -> IO ()
tensor_assign2_l _obj _idx0 _idx1 _val =
  [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)][$(int64_t _idx1)] = $(int64_t _val); }|]

-- | Perform the C++ assignment @(*_obj)[_idx0] = _val@ with a @double@ value.
--
-- ATen's @Tensor::operator[]@ yields a tensor, and the assignment goes
-- through ATen's rvalue @Tensor::operator=@.
-- NOTE(review): this is expected to write into @_obj@'s storage in place;
-- confirm against the ATen version in use. C++ exceptions are rethrown in
-- Haskell by 'C.throwBlock'.
tensor_assign1_d
  :: Ptr Tensor -- ^ tensor being written to
  -> Int64      -- ^ index passed to @operator[]@
  -> CDouble    -- ^ value assigned
  -> IO ()
tensor_assign1_d _obj _idx0 _val =
  [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)] = $(double _val); }|]

-- | Perform the C++ assignment @(*_obj)[_idx0][_idx1] = _val@ with a
-- @double@ value.
--
-- Each @operator[]@ yields a tensor, and the final assignment goes through
-- ATen's rvalue @Tensor::operator=@.
-- NOTE(review): this is expected to write into @_obj@'s storage in place;
-- confirm against the ATen version in use. C++ exceptions are rethrown in
-- Haskell by 'C.throwBlock'.
tensor_assign2_d
  :: Ptr Tensor -- ^ tensor being written to
  -> Int64      -- ^ first index
  -> Int64      -- ^ second index
  -> CDouble    -- ^ value assigned
  -> IO ()
tensor_assign2_d _obj _idx0 _idx1 _val =
  [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)][$(int64_t _idx1)] = $(double _val); }|]


-- | Perform the C++ assignment @(*_obj)[_idx0] = *_val@, where the
-- right-hand side is another tensor.
--
-- ATen's @Tensor::operator[]@ yields a tensor, and the assignment goes
-- through ATen's rvalue @Tensor::operator=(const Tensor&)@.
-- NOTE(review): this is expected to copy @*_val@ into the indexed part of
-- @_obj@ (shapes must be compatible); confirm against the ATen version in
-- use. C++ exceptions are rethrown in Haskell by 'C.throwBlock'.
tensor_assign1_t
  :: Ptr Tensor -- ^ tensor being written to
  -> Int64      -- ^ index passed to @operator[]@
  -> Ptr Tensor -- ^ tensor whose contents are assigned
  -> IO ()
tensor_assign1_t _obj _idx0 _val =
  [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)] = *$(at::Tensor* _val); }|]

-- | Perform the C++ assignment @(*_obj)[_idx0][_idx1] = *_val@, where the
-- right-hand side is another tensor.
--
-- Each @operator[]@ yields a tensor, and the final assignment goes through
-- ATen's rvalue @Tensor::operator=(const Tensor&)@.
-- NOTE(review): this is expected to copy @*_val@ into the indexed part of
-- @_obj@ (shapes must be compatible); confirm against the ATen version in
-- use. C++ exceptions are rethrown in Haskell by 'C.throwBlock'.
tensor_assign2_t
  :: Ptr Tensor -- ^ tensor being written to
  -> Int64      -- ^ first index
  -> Int64      -- ^ second index
  -> Ptr Tensor -- ^ tensor whose contents are assigned
  -> IO ()
tensor_assign2_t _obj _idx0 _idx1 _val =
  [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)][$(int64_t _idx1)] = *$(at::Tensor* _val); }|]


-- | Copy the dimension names of a tensor into a freshly allocated
-- @std::vector<at::Dimname>@.
--
-- The vector is allocated with @new@ and returned as a raw pointer; nothing
-- in this function frees it, so the caller owns it.
tensor_names
  :: Ptr Tensor          -- ^ tensor whose @names()@ are read
  -> IO (Ptr DimnameList)
tensor_names _obj =
  -- The range constructor replaces the manual push_back loop, which compared
  -- a signed @int@ index against the unsigned @size()@.
  [C.throwBlock| std::vector<at::Dimname>* {
      auto ref = (*$(at::Tensor* _obj)).names();
      return new std::vector<at::Dimname>(ref.begin(), ref.end());
  }|]

-- | Return @input@ moved to the device that @reference@ lives on.
--
-- Calls @input.to(reference.device())@ and wraps the result in a
-- @new@-allocated @at::Tensor@; the caller owns the returned pointer.
-- Neither argument is modified.
tensor_to_device
  :: Ptr Tensor      -- ^ reference: only its @device()@ is read
  -> Ptr Tensor      -- ^ input: the tensor to move
  -> IO (Ptr Tensor)
tensor_to_device reference input =
  [C.throwBlock| at::Tensor* {
      auto d = (*$(at::Tensor* reference)).device();
      return new at::Tensor((*$(at::Tensor* input)).to(d));
  }|]

-- | Allocate a new uninitialised tensor via @torch::empty@, using the
-- dimensions in @_size@ and the options pointed to by @_options@.
--
-- The result is a @new@-allocated @at::Tensor@; nothing here frees it, so
-- the caller owns the returned pointer. An empty @_size@ passes an empty
-- shape to @torch::empty@.
new_empty_tensor
  :: [Int]             -- ^ shape, one entry per dimension
  -> Ptr TensorOptions -- ^ dtype/device/etc. options (read only)
  -> IO (Ptr Tensor)
new_empty_tensor _size _options =
  -- The shape is marshalled as a temporary Int64 array:
  --  * each dimension keeps its full 64-bit range (the previous fast paths
  --    narrowed dimensions to C @int@ and silently wrapped values >= 2^31);
  --  * 'withArrayLen' frees the array even if torch::empty throws (the
  --    previous heap-allocated std::vector leaked in that case).
  -- torch::empty copies the sizes, so the array need not outlive the call.
  withArrayLen (map fromIntegral _size :: [Int64]) $ \len dims -> do
    let ndim = fromIntegral len :: Int64
    [C.throwBlock| at::Tensor* {
      return new at::Tensor(torch::empty(
        at::IntArrayRef($(int64_t* dims), static_cast<size_t>($(int64_t ndim))),
        *$(at::TensorOptions* _options)));
    }|]