{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Torch.DType where

import Data.Complex
import qualified Numeric.Half as N
import Data.Int
import Data.Reflection
import Data.Word
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Type as ATen

-- | Tensor element types.
--
-- Each constructor's comment gives the name of the ATen scalar-type constant
-- (@ATen.k<Name>@) it is cast to by the 'Castable' instance in this module.
data DType
  = -- | Bool
    Bool
  | -- | Byte
    UInt8
  | -- | Char
    Int8
  | -- | Short
    Int16
  | -- | Int
    Int32
  | -- | Long
    Int64
  | -- | Half
    Half
  | -- | Float
    Float
  | -- | Double
    Double
  | -- | ComplexHalf
    ComplexHalf
  | -- | ComplexFloat
    ComplexFloat
  | -- | ComplexDouble
    ComplexDouble
  | -- | QInt8
    QInt8
  | -- | QUInt8
    QUInt8
  | -- | QInt32
    QInt32
  | -- | BFloat16
    BFloat16
  deriving (Eq, Show, Read)

-- | The Haskell type @Bool@ reflects to the 'DType' value @Bool@.
instance Reifies Bool DType where
  reflect _ = Bool

-- | The promoted constructor @'Bool@ reflects back to its term-level value.
instance Reifies 'Bool DType where
  reflect _ = Bool

-- | The Haskell type @Word8@ reflects to the 'DType' value @UInt8@.
instance Reifies Word8 DType where
  reflect _ = UInt8

-- | The Haskell type @Int8@ reflects to the 'DType' value @Int8@.
instance Reifies Int8 DType where
  reflect _ = Int8

-- | The promoted constructor @'Int8@ reflects back to its term-level value.
instance Reifies 'Int8 DType where
  reflect _ = Int8

-- | The Haskell type @Int16@ reflects to the 'DType' value @Int16@.
instance Reifies Int16 DType where
  reflect _ = Int16

-- | The promoted constructor @'Int16@ reflects back to its term-level value.
instance Reifies 'Int16 DType where
  reflect _ = Int16

-- | The Haskell type @Int32@ reflects to the 'DType' value @Int32@.
instance Reifies Int32 DType where
  reflect _ = Int32

-- | The promoted constructor @'Int32@ reflects back to its term-level value.
instance Reifies 'Int32 DType where
  reflect _ = Int32

-- | The Haskell type @Int@ reflects to the 'DType' value @Int64@, the same
-- value the @Int64@ instance produces.
instance Reifies Int DType where
  reflect _ = Int64

-- | The Haskell type @Int64@ reflects to the 'DType' value @Int64@.
instance Reifies Int64 DType where
  reflect _ = Int64

-- | The promoted constructor @'Int64@ reflects back to its term-level value.
instance Reifies 'Int64 DType where
  reflect _ = Int64

-- | The @Half@ type from "Numeric.Half" (imported qualified as @N@) reflects
-- to the 'DType' value @Half@.
instance Reifies N.Half DType where
  reflect _ = Half

-- | The promoted constructor @'Half@ reflects back to its term-level value.
instance Reifies 'Half DType where
  reflect _ = Half

-- | The Haskell type @Float@ reflects to the 'DType' value @Float@.
instance Reifies Float DType where
  reflect _ = Float

-- | The promoted constructor @'Float@ reflects back to its term-level value.
instance Reifies 'Float DType where
  reflect _ = Float

-- | The Haskell type @Double@ reflects to the 'DType' value @Double@.
instance Reifies Double DType where
  reflect _ = Double

-- | The promoted constructor @'Double@ reflects back to its term-level value.
instance Reifies 'Double DType where
  reflect _ = Double

-- | @Complex N.Half@ reflects to the 'DType' value @ComplexHalf@.
instance Reifies (Complex N.Half) DType where
  reflect _ = ComplexHalf

-- | The promoted constructor @'ComplexHalf@ reflects back to its term-level
-- value.
instance Reifies 'ComplexHalf DType where
  reflect _ = ComplexHalf

-- | @Complex Float@ reflects to the 'DType' value @ComplexFloat@.
instance Reifies (Complex Float) DType where
  reflect _ = ComplexFloat

-- | The promoted constructor @'ComplexFloat@ reflects back to its term-level
-- value.
instance Reifies 'ComplexFloat DType where
  reflect _ = ComplexFloat

-- | @Complex Double@ reflects to the 'DType' value @ComplexDouble@.
instance Reifies (Complex Double) DType where
  reflect _ = ComplexDouble

-- | The promoted constructor @'ComplexDouble@ reflects back to its term-level
-- value.
instance Reifies 'ComplexDouble DType where
  reflect _ = ComplexDouble

-- | Conversion between 'DType' and the ATen scalar-type code.
--
-- * 'cast' passes the continuation the @ATen.k*@ constant that corresponds to
--   the given constructor.
-- * 'uncast' compares an incoming code against those same constants, in
--   declaration order, and passes the continuation the matching constructor.
--   A code that equals none of them raises an 'error' naming the code.
instance Castable DType ATen.ScalarType where
  cast Bool f = f ATen.kBool
  cast UInt8 f = f ATen.kByte
  cast Int8 f = f ATen.kChar
  cast Int16 f = f ATen.kShort
  cast Int32 f = f ATen.kInt
  cast Int64 f = f ATen.kLong
  cast Half f = f ATen.kHalf
  cast Float f = f ATen.kFloat
  cast Double f = f ATen.kDouble
  cast ComplexHalf f = f ATen.kComplexHalf
  cast ComplexFloat f = f ATen.kComplexFloat
  cast ComplexDouble f = f ATen.kComplexDouble
  cast QInt8 f = f ATen.kQInt8
  cast QUInt8 f = f ATen.kQUInt8
  cast QInt32 f = f ATen.kQInt32
  cast BFloat16 f = f ATen.kBFloat16

  uncast x f
    | x == ATen.kBool = f Bool
    | x == ATen.kByte = f UInt8
    | x == ATen.kChar = f Int8
    | x == ATen.kShort = f Int16
    | x == ATen.kInt = f Int32
    | x == ATen.kLong = f Int64
    | x == ATen.kHalf = f Half
    | x == ATen.kFloat = f Float
    | x == ATen.kDouble = f Double
    | x == ATen.kComplexHalf = f ComplexHalf
    | x == ATen.kComplexFloat = f ComplexFloat
    | x == ATen.kComplexDouble = f ComplexDouble
    | x == ATen.kQInt8 = f QInt8
    | x == ATen.kQUInt8 = f QUInt8
    | x == ATen.kQInt32 = f QInt32
    | x == ATen.kBFloat16 = f BFloat16
    -- Keep the guards total: a code with no matching constructor reports
    -- itself instead of surfacing as an anonymous pattern-match failure.
    | otherwise =
        error ("Torch.DType.uncast: unsupported ATen scalar type code " ++ show x)

-- | 'True' for @Bool@, @UInt8@, @Int8@, @Int16@, @Int32@ and @Int64@;
-- 'False' for every other constructor, including the @Q*@ ones.
--
-- Every constructor is listed explicitly (no wildcard) so that adding a new
-- 'DType' constructor triggers an incomplete-pattern warning here.
isIntegral :: DType -> Bool
isIntegral Bool = True
isIntegral UInt8 = True
isIntegral Int8 = True
isIntegral Int16 = True
isIntegral Int32 = True
isIntegral Int64 = True
isIntegral Half = False
isIntegral Float = False
isIntegral Double = False
isIntegral ComplexHalf = False
isIntegral ComplexFloat = False
isIntegral ComplexDouble = False
isIntegral QInt8 = False
isIntegral QUInt8 = False
isIntegral QInt32 = False
isIntegral BFloat16 = False

-- | 'True' exactly for @ComplexHalf@, @ComplexFloat@ and @ComplexDouble@;
-- 'False' for every other constructor.
--
-- Every constructor is listed explicitly (no wildcard) so that adding a new
-- 'DType' constructor triggers an incomplete-pattern warning here.
isComplex :: DType -> Bool
isComplex Bool = False
isComplex UInt8 = False
isComplex Int8 = False
isComplex Int16 = False
isComplex Int32 = False
isComplex Int64 = False
isComplex Half = False
isComplex Float = False
isComplex Double = False
isComplex ComplexHalf = True
isComplex ComplexFloat = True
isComplex ComplexDouble = True
isComplex QInt8 = False
isComplex QUInt8 = False
isComplex QInt32 = False
isComplex BFloat16 = False

-- | Number of bytes associated with a 'DType' (presumably the storage size of
-- a single element; confirm against callers).
--
-- Each complex dtype is given twice the size of its component dtype
-- (@ComplexHalf@ = 4, @ComplexFloat@ = 8, @ComplexDouble@ = 16).
byteLength :: DType -> Int
byteLength dtype =
  case dtype of
    Bool -> 1
    UInt8 -> 1
    Int8 -> 1
    Int16 -> 2
    Int32 -> 4
    Int64 -> 8
    Half -> 2
    Float -> 4
    Double -> 8
    ComplexHalf -> 4
    ComplexFloat -> 8
    ComplexDouble -> 16
    QInt8 -> 1
    QUInt8 -> 1
    QInt32 -> 4
    BFloat16 -> 2