-- generated by using spec/Declarations.yaml

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}

module Torch.Internal.Const where

import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Cpp.Unsafe as C
import qualified Language.C.Inline.Context as C
import qualified Language.C.Types as C
import qualified Data.Map as Map

import Foreign.C.String
import Foreign.C.Types
import Foreign
import Torch.Internal.Type

-- Configure inline-c for this module: start from the C++ context and
-- extend it with 'typeTable' (presumably the project's C++-to-Haskell type
-- mapping from Torch.Internal.Type — confirm) so the quasiquotes below can
-- name the C/C++ types they return.
C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable }

-- Headers declaring the ATen/c10 enums read by the constants below
-- (at::ScalarType, at::Reduction, the at::k* layouts).
-- NOTE(review): at::DeviceType and at::Backend are not included explicitly;
-- presumably they are reachable transitively through these headers — confirm.
C.include "<ATen/ScalarType.h>"
C.include "<ATen/core/Reduction.h>"
C.include "<c10/core/Layout.h>"



-- | Code of @at::ScalarType::Byte@, cast to @int8_t@.
kByte :: ScalarType
kByte = [C.pure| int8_t { (int8_t) at::ScalarType::Byte } |]

-- | Code of @at::ScalarType::Char@, cast to @int8_t@.
kChar :: ScalarType
kChar = [C.pure| int8_t { (int8_t) at::ScalarType::Char } |]

-- | Code of @at::ScalarType::Double@, cast to @int8_t@.
kDouble :: ScalarType
kDouble = [C.pure| int8_t { (int8_t) at::ScalarType::Double } |]

-- | Code of @at::ScalarType::Float@, cast to @int8_t@.
kFloat :: ScalarType
kFloat = [C.pure| int8_t { (int8_t) at::ScalarType::Float } |]

-- | Code of @at::ScalarType::Int@, cast to @int8_t@.
kInt :: ScalarType
kInt = [C.pure| int8_t { (int8_t) at::ScalarType::Int } |]

-- | Code of @at::ScalarType::Long@, cast to @int8_t@.
kLong :: ScalarType
kLong = [C.pure| int8_t { (int8_t) at::ScalarType::Long } |]

-- | Code of @at::ScalarType::Short@, cast to @int8_t@.
kShort :: ScalarType
kShort = [C.pure| int8_t { (int8_t) at::ScalarType::Short } |]

-- | Code of @at::ScalarType::Half@, cast to @int8_t@.
kHalf :: ScalarType
kHalf = [C.pure| int8_t { (int8_t) at::ScalarType::Half } |]

-- | Code of @at::ScalarType::Bool@, cast to @int8_t@.
kBool :: ScalarType
kBool = [C.pure| int8_t { (int8_t) at::ScalarType::Bool } |]

-- | Code of @at::ScalarType::ComplexHalf@, cast to @int8_t@.
kComplexHalf :: ScalarType
kComplexHalf = [C.pure| int8_t { (int8_t) at::ScalarType::ComplexHalf } |]

-- | Code of @at::ScalarType::ComplexFloat@, cast to @int8_t@.
kComplexFloat :: ScalarType
kComplexFloat = [C.pure| int8_t { (int8_t) at::ScalarType::ComplexFloat } |]

-- | Code of @at::ScalarType::ComplexDouble@, cast to @int8_t@.
kComplexDouble :: ScalarType
kComplexDouble = [C.pure| int8_t { (int8_t) at::ScalarType::ComplexDouble } |]

-- | Code of @at::ScalarType::QInt8@, cast to @int8_t@.
kQInt8 :: ScalarType
kQInt8 = [C.pure| int8_t { (int8_t) at::ScalarType::QInt8 } |]

-- | Code of @at::ScalarType::QUInt8@, cast to @int8_t@.
kQUInt8 :: ScalarType
kQUInt8 = [C.pure| int8_t { (int8_t) at::ScalarType::QUInt8 } |]

-- | Code of @at::ScalarType::QInt32@, cast to @int8_t@.
kQInt32 :: ScalarType
kQInt32 = [C.pure| int8_t { (int8_t) at::ScalarType::QInt32 } |]

-- | Code of @at::ScalarType::BFloat16@, cast to @int8_t@.
kBFloat16 :: ScalarType
kBFloat16 = [C.pure| int8_t { (int8_t) at::ScalarType::BFloat16 } |]

-- | Code of @at::ScalarType::Undefined@, cast to @int8_t@.
kUndefined :: ScalarType
kUndefined = [C.pure| int8_t { (int8_t) at::ScalarType::Undefined } |]

-- | Code of @at::DeviceType::CPU@, cast to @int16_t@.
kCPU :: DeviceType
kCPU = [C.pure| int16_t { (int16_t) at::DeviceType::CPU } |]

-- | Code of @at::DeviceType::CUDA@, cast to @int16_t@.
kCUDA :: DeviceType
kCUDA = [C.pure| int16_t { (int16_t) at::DeviceType::CUDA } |]

-- | Code of @at::DeviceType::MKLDNN@, cast to @int16_t@.
kMKLDNN :: DeviceType
kMKLDNN = [C.pure| int16_t { (int16_t) at::DeviceType::MKLDNN } |]

-- | Code of @at::DeviceType::OPENGL@, cast to @int16_t@.
kOPENGL :: DeviceType
kOPENGL = [C.pure| int16_t { (int16_t) at::DeviceType::OPENGL } |]

-- | Code of @at::DeviceType::OPENCL@, cast to @int16_t@.
kOPENCL :: DeviceType
kOPENCL = [C.pure| int16_t { (int16_t) at::DeviceType::OPENCL } |]

-- | Code of @at::DeviceType::IDEEP@, cast to @int16_t@.
kIDEEP :: DeviceType
kIDEEP = [C.pure| int16_t { (int16_t) at::DeviceType::IDEEP } |]

-- | Code of @at::DeviceType::HIP@, cast to @int16_t@.
kHIP :: DeviceType
kHIP = [C.pure| int16_t { (int16_t) at::DeviceType::HIP } |]

-- | Code of @at::DeviceType::FPGA@, cast to @int16_t@.
kFPGA :: DeviceType
kFPGA = [C.pure| int16_t { (int16_t) at::DeviceType::FPGA } |]

-- | Code of @at::DeviceType::XLA@, cast to @int16_t@.
kXLA :: DeviceType
kXLA = [C.pure| int16_t { (int16_t) at::DeviceType::XLA } |]

-- | Code of @at::DeviceType::Vulkan@, cast to @int16_t@.
kVulkan :: DeviceType
kVulkan = [C.pure| int16_t { (int16_t) at::DeviceType::Vulkan } |]

-- | Code of @at::DeviceType::Metal@, cast to @int16_t@.
kMetal :: DeviceType
kMetal = [C.pure| int16_t { (int16_t) at::DeviceType::Metal } |]

-- | Code of @at::DeviceType::XPU@, cast to @int16_t@.
kXPU :: DeviceType
kXPU = [C.pure| int16_t { (int16_t) at::DeviceType::XPU } |]

-- | Value of @at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES@, cast to
-- @int16_t@.
kCOMPILE_TIME_MAX_DEVICE_TYPES :: DeviceType
kCOMPILE_TIME_MAX_DEVICE_TYPES = [C.pure| int16_t { (int16_t) at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES } |]

-- TODO: expose the remaining at::Reduction values (only Mean is bound so far).

-- | Value of @at::Reduction::Mean@, cast to @int64_t@.
kMean :: Int64
kMean = [C.pure| int64_t { (int64_t) at::Reduction::Mean } |]

-- | Code of @at::Backend::CPU@, cast to C @int@.
bCPU :: Backend
bCPU = [C.pure| int { (int) at::Backend::CPU } |]

-- | Code of @at::Backend::CUDA@, cast to C @int@.
bCUDA :: Backend
bCUDA = [C.pure| int { (int) at::Backend::CUDA } |]

-- | Code of @at::Backend::HIP@, cast to C @int@.
bHIP :: Backend
bHIP = [C.pure| int { (int) at::Backend::HIP } |]

-- | Code of @at::Backend::SparseCPU@, cast to C @int@.
bSparseCPU :: Backend
bSparseCPU = [C.pure| int { (int) at::Backend::SparseCPU } |]

-- | Code of @at::Backend::SparseCUDA@, cast to C @int@.
bSparseCUDA :: Backend
bSparseCUDA = [C.pure| int { (int) at::Backend::SparseCUDA } |]

-- | Code of @at::Backend::SparseHIP@, cast to C @int@.
bSparseHIP :: Backend
bSparseHIP = [C.pure| int { (int) at::Backend::SparseHIP } |]

-- | Code of @at::Backend::XLA@, cast to C @int@.
bXLA :: Backend
bXLA = [C.pure| int { (int) at::Backend::XLA } |]

-- | Code of @at::Backend::Undefined@, cast to C @int@.
bUndefined :: Backend
bUndefined = [C.pure| int { (int) at::Backend::Undefined } |]

-- | Value of @at::Backend::NumOptions@, cast to C @int@.
bNumOptions :: Backend
bNumOptions = [C.pure| int { (int) at::Backend::NumOptions } |]

-- | Code of the @at::kStrided@ layout, cast to @int8_t@.
-- The original carried a second, conflicting @:: ScalarType@ signature;
-- only the 'Layout' signature is kept.
kStrided :: Layout
kStrided = [C.pure| int8_t { (int8_t) at::kStrided } |]

-- | Code of the @at::kSparse@ layout, cast to @int8_t@.
-- The original carried a second, conflicting @:: ScalarType@ signature;
-- only the 'Layout' signature is kept.
kSparse :: Layout
kSparse = [C.pure| int8_t { (int8_t) at::kSparse } |]

-- | Code of the @at::kMkldnn@ layout, cast to @int8_t@.
-- The original carried a second, conflicting @:: ScalarType@ signature;
-- only the 'Layout' signature is kept.
kMkldnn :: Layout
kMkldnn = [C.pure| int8_t { (int8_t) at::kMkldnn } |]