{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.DType where

import Data.Kind (Type)
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import Data.Singletons.TH (genSingletons)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..))
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Type as ATen

-- | Tensor element data types.
--
-- The comment on each constructor gives the name libtorch uses for that
-- scalar type; the 'Castable' instance in this module maps every
-- constructor to the matching @ATen.k*@ constant.
data DType
  = -- | Bool
    Bool
  | -- | Byte
    UInt8
  | -- | Char
    Int8
  | -- | Short
    Int16
  | -- | Int
    Int32
  | -- | Long
    Int64
  | -- | Half
    Half
  | -- | Float
    Float
  | -- | Double
    Double
  | -- | ComplexHalf
    ComplexHalf
  | -- | ComplexFloat
    ComplexFloat
  | -- | ComplexDouble
    ComplexDouble
  | -- | QInt8
    QInt8
  | -- | QUInt8
    QUInt8
  | -- | QInt32
    QInt32
  | -- | BFloat16
    BFloat16
  deriving stock (Eq, Show, Read)

-- Template Haskell: derive the singleton machinery for 'DType', including
-- the 'SDType' singleton type used throughout this module.
$(genSingletons [''DType])

-- Singletons of 'DType' are printable (needed to derive 'Show' for types
-- that embed an 'SDType').
deriving stock instance Show (SDType dType)

-- | Conversion between 'DType' and libtorch's 'ATen.ScalarType'.
--
-- 'cast' maps each 'DType' constructor to the corresponding @ATen.k*@
-- constant and passes it to the continuation; 'uncast' performs the inverse
-- lookup.
instance Castable DType ATen.ScalarType where
  cast Bool f = f ATen.kBool
  cast UInt8 f = f ATen.kByte
  cast Int8 f = f ATen.kChar
  cast Int16 f = f ATen.kShort
  cast Int32 f = f ATen.kInt
  cast Int64 f = f ATen.kLong
  cast Half f = f ATen.kHalf
  cast Float f = f ATen.kFloat
  cast Double f = f ATen.kDouble
  cast ComplexHalf f = f ATen.kComplexHalf
  cast ComplexFloat f = f ATen.kComplexFloat
  cast ComplexDouble f = f ATen.kComplexDouble
  cast QInt8 f = f ATen.kQInt8
  cast QUInt8 f = f ATen.kQUInt8
  cast QInt32 f = f ATen.kQInt32
  cast BFloat16 f = f ATen.kBFloat16

  uncast x f
    | x == ATen.kBool = f Bool
    | x == ATen.kByte = f UInt8
    | x == ATen.kChar = f Int8
    | x == ATen.kShort = f Int16
    | x == ATen.kInt = f Int32
    | x == ATen.kLong = f Int64
    | x == ATen.kHalf = f Half
    | x == ATen.kFloat = f Float
    | x == ATen.kDouble = f Double
    | x == ATen.kComplexHalf = f ComplexHalf
    | x == ATen.kComplexFloat = f ComplexFloat
    | x == ATen.kComplexDouble = f ComplexDouble
    | x == ATen.kQInt8 = f QInt8
    | x == ATen.kQUInt8 = f QUInt8
    | x == ATen.kQInt32 = f QInt32
    | x == ATen.kBFloat16 = f BFloat16
    -- Any other scalar type has no 'DType' counterpart. Raise a descriptive
    -- 'IOError' in IO rather than falling off the end of the guards with a
    -- generic non-exhaustive-guards failure.
    | otherwise =
      ioError . userError $
        "Torch.GraduallyTyped.DType.uncast: ATen.ScalarType has no corresponding DType"

-- | Reflects a promoted 'DType' back to its value-level 'DType'.
--
-- @dTypeVal \@dType@ returns the constructor that @dType@ was promoted from.
-- @dType@ does not occur in the method's type, so call sites select the
-- instance with a type application. Every 'DType' constructor has an
-- instance.
class KnownDType (dType :: DType) where
  dTypeVal :: DType

instance KnownDType 'Bool where
  dTypeVal = Bool

instance KnownDType 'UInt8 where
  dTypeVal = UInt8

instance KnownDType 'Int8 where
  dTypeVal = Int8

instance KnownDType 'Int16 where
  dTypeVal = Int16

instance KnownDType 'Int32 where
  dTypeVal = Int32

instance KnownDType 'Int64 where
  dTypeVal = Int64

instance KnownDType 'Half where
  dTypeVal = Half

instance KnownDType 'Float where
  dTypeVal = Float

instance KnownDType 'Double where
  dTypeVal = Double

instance KnownDType 'ComplexHalf where
  dTypeVal = ComplexHalf

instance KnownDType 'ComplexFloat where
  dTypeVal = ComplexFloat

instance KnownDType 'ComplexDouble where
  dTypeVal = ComplexDouble

instance KnownDType 'QInt8 where
  dTypeVal = QInt8

instance KnownDType 'QUInt8 where
  dTypeVal = QUInt8

instance KnownDType 'QInt32 where
  dTypeVal = QInt32

instance KnownDType 'BFloat16 where
  dTypeVal = BFloat16

-- | Data type to represent whether or not the tensor data type is checked,
-- that is, known to the compiler.
--
-- It is used both promoted, as the kind @DataType DType@ (e.g.
-- @'DataType 'Float@ or @'UncheckedDataType@), and as an ordinary value
-- (e.g. the result of 'dataTypeVal').
data DataType (dType :: Type) where
  -- | The tensor data type is unknown to the compiler.
  UncheckedDataType :: forall dType. DataType dType
  -- | The tensor data type is known to the compiler.
  DataType :: forall dType. dType -> DataType dType
  deriving stock (Show)

-- | Singleton for a type-level @'DataType' 'DType'@.
--
-- The checked case carries the 'SDType' singleton of its 'DType'. The
-- unchecked case carries the 'DType' only as a plain runtime value.
type SDataType :: DataType DType -> Type
data SDataType dataType where
  SUncheckedDataType :: DType -> SDataType 'UncheckedDataType
  SDataType :: SDType dType -> SDataType ('DataType dType)

deriving stock instance Show (SDataType dataType)

-- The singleton type for the kind @DataType DType@ is 'SDataType'.
type instance Sing = SDataType

-- | A checked data type has a canonical singleton whenever its 'DType' does.
--
-- There is no instance for 'UncheckedDataType': its singleton
-- ('SUncheckedDataType') needs a runtime 'DType' value that cannot be
-- conjured from the type alone.
instance SingI dType => SingI ('DataType (dType :: DType)) where
  sing = SDataType (sing @dType)

-- | Moves between 'SDataType' singletons and the value-level
-- @'IsChecked' 'DType'@.
--
-- 'SUncheckedDataType' demotes to 'Unchecked' and 'SDataType' demotes to
-- 'Checked'; 'toSing' goes the other way. For the checked case, 'toSing'
-- first promotes the 'DType' to an 'SDType' via 'withSomeSing'.
instance SingKind (DataType DType) where
  type Demote (DataType DType) = IsChecked DType
  fromSing (SUncheckedDataType dType) = Unchecked dType
  fromSing (SDataType dType) = Checked (fromSing dType)
  toSing (Unchecked dType) = SomeSing (SUncheckedDataType dType)
  toSing (Checked dType) = withSomeSing dType (SomeSing . SDataType)

-- | Reflects a type-level @'DataType' 'DType'@ to its value-level counterpart.
class KnownDataType (dataType :: DataType DType) where
  dataTypeVal :: DataType DType

-- | An unchecked data type reflects to 'UncheckedDataType'.
instance KnownDataType 'UncheckedDataType where
  dataTypeVal = UncheckedDataType

-- | A checked data type reflects its 'DType' via 'KnownDType'.
instance
  (KnownDType dType) =>
  KnownDataType ('DataType dType)
  where
  dataTypeVal = DataType (dTypeVal @dType)

-- GetDataTypes collects every type of kind @DataType DType@ that occurs in a
-- type of arbitrary kind @k@ and returns them as a type-level list.
--
-- * A type of kind @DataType DType@ yields a one-element list.
-- * A type application @f g@ is split, and the results for @f@ and @g@ are
--   joined with 'Concat' (presumably type-level list append, given the
--   examples below — confirm in Torch.GraduallyTyped.Prelude).
-- * Anything else yields the empty list.
--
-- >>> :kind! GetDataTypes ('DataType 'Float)
-- GetDataTypes ('DataType 'Float) :: [DataType DType]
-- = '[ 'DataType 'Float]
-- >>> :kind! GetDataTypes '[ 'DataType 'Bool, 'DataType 'Float]
-- GetDataTypes '[ 'DataType 'Bool, 'DataType 'Float] :: [DataType
--                                                          DType]
-- = '[ 'DataType 'Bool, 'DataType 'Float]
-- >>> :kind! GetDataTypes ('Just ('DataType 'Bool))
-- GetDataTypes ('Just ('DataType 'Bool)) :: [DataType DType]
-- = '[ 'DataType 'Bool]
type GetDataTypes :: k -> [DataType DType]
type family GetDataTypes f where
  -- Must come first: 'DataType 'Float is itself an application, and this
  -- equation keeps it whole instead of letting the next one split it.
  GetDataTypes (a :: DataType DType) = '[a]
  -- Recurse into both sides of a type application.
  GetDataTypes (f g) = Concat (GetDataTypes f) (GetDataTypes g)
  -- Fallback: no data types found.
  GetDataTypes _ = '[]