{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.DType where
import Data.Kind (Type)
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import Data.Singletons.TH (genSingletons)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..))
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Type as ATen
-- | Value-level enumeration of scalar data types.
--
-- Every constructor is paired with exactly one ATen scalar-type constant by
-- the 'Castable' instance in this module; that pairing is repeated on each
-- constructor below.
data DType
  = -- | Paired with @ATen.kBool@.
    Bool
  | -- | Paired with @ATen.kByte@.
    UInt8
  | -- | Paired with @ATen.kChar@.
    Int8
  | -- | Paired with @ATen.kShort@.
    Int16
  | -- | Paired with @ATen.kInt@.
    Int32
  | -- | Paired with @ATen.kLong@.
    Int64
  | -- | Paired with @ATen.kHalf@.
    Half
  | -- | Paired with @ATen.kFloat@.
    Float
  | -- | Paired with @ATen.kDouble@.
    Double
  | -- | Paired with @ATen.kComplexHalf@.
    ComplexHalf
  | -- | Paired with @ATen.kComplexFloat@.
    ComplexFloat
  | -- | Paired with @ATen.kComplexDouble@.
    ComplexDouble
  | -- | Paired with @ATen.kQInt8@.
    QInt8
  | -- | Paired with @ATen.kQUInt8@.
    QUInt8
  | -- | Paired with @ATen.kQInt32@.
    QInt32
  | -- | Paired with @ATen.kBFloat16@.
    BFloat16
  deriving (Eq, Show, Read)
-- Template Haskell splice that generates the singleton definitions for
-- 'DType'. The generated singleton type is the 'SDType' referenced by the
-- declarations that follow, so this splice has to stay above them.
genSingletons [''DType]
-- | 'Show' instance for the generated 'SDType' singleton, derived with the
-- stock strategy.
deriving stock instance Show (SDType (dType :: DType))
-- | Two-way conversion between 'DType' and ATen scalar-type constants.
--
-- 'cast' passes the ATen constant paired with the given 'DType' to the
-- continuation. 'uncast' compares an ATen scalar type against each known
-- constant in turn and passes the matching 'DType' to the continuation.
instance Castable DType ATen.ScalarType where
  cast Bool f = f ATen.kBool
  cast UInt8 f = f ATen.kByte
  cast Int8 f = f ATen.kChar
  cast Int16 f = f ATen.kShort
  cast Int32 f = f ATen.kInt
  cast Int64 f = f ATen.kLong
  cast Half f = f ATen.kHalf
  cast Float f = f ATen.kFloat
  cast Double f = f ATen.kDouble
  cast ComplexHalf f = f ATen.kComplexHalf
  cast ComplexFloat f = f ATen.kComplexFloat
  cast ComplexDouble f = f ATen.kComplexDouble
  cast QInt8 f = f ATen.kQInt8
  cast QUInt8 f = f ATen.kQUInt8
  cast QInt32 f = f ATen.kQInt32
  cast BFloat16 f = f ATen.kBFloat16

  -- NOTE(review): there is no fallback guard. A scalar type that equals none
  -- of the constants below falls through every guard and fails at run time
  -- with a pattern-match error.
  uncast x f
    | x == ATen.kBool = f Bool
    | x == ATen.kByte = f UInt8
    | x == ATen.kChar = f Int8
    | x == ATen.kShort = f Int16
    | x == ATen.kInt = f Int32
    | x == ATen.kLong = f Int64
    | x == ATen.kHalf = f Half
    | x == ATen.kFloat = f Float
    | x == ATen.kDouble = f Double
    | x == ATen.kComplexHalf = f ComplexHalf
    | x == ATen.kComplexFloat = f ComplexFloat
    | x == ATen.kComplexDouble = f ComplexDouble
    | x == ATen.kQInt8 = f QInt8
    | x == ATen.kQUInt8 = f QUInt8
    | x == ATen.kQInt32 = f QInt32
    | x == ATen.kBFloat16 = f BFloat16
-- | Reflects a type-level 'DType' back to the corresponding value.
--
-- The class variable does not appear in the method's type, so callers pick
-- the instance with a visible type application (this module does so with
-- @dTypeVal \@dType@).
class KnownDType (dType :: DType) where
  -- | The value-level 'DType' for the promoted constructor @dType@.
  dTypeVal :: DType
-- One instance per promoted constructor; each returns the constructor of the
-- same name.
--
-- NOTE(review): only 'Bool' through 'Double' are covered. The complex,
-- quantized and 'BFloat16' constructors of 'DType' have no instance here;
-- confirm whether that is intentional.

instance KnownDType 'Bool where
  dTypeVal = Bool

instance KnownDType 'UInt8 where
  dTypeVal = UInt8

instance KnownDType 'Int8 where
  dTypeVal = Int8

instance KnownDType 'Int16 where
  dTypeVal = Int16

instance KnownDType 'Int32 where
  dTypeVal = Int32

instance KnownDType 'Int64 where
  dTypeVal = Int64

instance KnownDType 'Half where
  dTypeVal = Half

instance KnownDType 'Float where
  dTypeVal = Float

instance KnownDType 'Double where
  dTypeVal = Double
-- | A data type annotation that is either absent or explicit.
--
-- In this module the type is used promoted, at kind @DataType DType@, where
-- @'UncheckedDataType@ carries no type-level information and
-- @'DataType dType@ carries a promoted 'DType'.
--
-- The derived 'Show' instance requires 'Show' for the wrapped type.
data DataType (dType :: Type) where
  -- | No data type is recorded.
  UncheckedDataType :: forall dType. DataType dType
  -- | The data type is recorded as the wrapped value.
  DataType :: forall dType. dType -> DataType dType
  deriving (Show)
-- | Singleton type for the kind @DataType DType@ (it is registered as such by
-- the @type instance Sing = SDataType@ declaration in this module).
data SDataType (dataType :: DataType DType) where
  -- | Singleton for @'UncheckedDataType@. Nothing is known at the type level,
  -- so the 'DType' is carried as an ordinary run-time value.
  SUncheckedDataType ::
    DType ->
    SDataType 'UncheckedDataType
  -- | Singleton for @'DataType dType@, wrapping the 'SDType' singleton of the
  -- promoted 'DType'.
  SDataType ::
    forall dType.
    SDType dType ->
    SDataType ('DataType dType)
-- | 'Show' instance for 'SDataType', derived with the stock strategy.
deriving stock instance Show (SDataType (dataType :: DataType DType))
-- | Registers 'SDataType' as the singleton family for the kind
-- @DataType DType@.
type instance Sing = SDataType
-- | A checked data type has an implicit singleton whenever the promoted
-- 'DType' inside it does: that 'SDType' singleton is wrapped in 'SDataType'.
instance SingI dType => SingI ('DataType (dType :: DType)) where
  sing = SDataType $ sing @dType
-- | Conversion between singletons of kind @DataType DType@ and values of
-- type @IsChecked DType@.
instance SingKind (DataType DType) where
  type Demote (DataType DType) = IsChecked DType

  -- An unchecked singleton already carries its 'DType' as a value; a checked
  -- one demotes the wrapped 'SDType' singleton.
  fromSing (SUncheckedDataType dType) = Unchecked dType
  fromSing (SDataType dType) = Checked . fromSing $ dType

  -- An unchecked value is stored as-is inside 'SUncheckedDataType'. A checked
  -- value is first promoted to an 'SDType' by 'withSomeSing' and then wrapped
  -- in 'SDataType'.
  toSing (Unchecked dType) = SomeSing . SUncheckedDataType $ dType
  toSing (Checked dType) = withSomeSing dType $ SomeSing . SDataType
-- | Reflects a type of kind @DataType DType@ back to a value of type
-- @DataType DType@.
--
-- The class variable does not appear in the method's type, so the instance
-- has to be chosen with a visible type application.
class KnownDataType (dataType :: DataType DType) where
  -- | The value-level counterpart of @dataType@.
  dataTypeVal :: DataType DType
-- | The unchecked data type reflects to the 'UncheckedDataType' value.
instance KnownDataType 'UncheckedDataType where
  dataTypeVal = UncheckedDataType
-- | A checked data type reflects to 'DataType' applied to the value obtained
-- from the 'KnownDType' instance of the promoted 'DType'.
instance
  (KnownDType dType) =>
  KnownDataType ('DataType dType)
  where
  dataTypeVal = DataType (dTypeVal @dType)
-- | Collects the types of kind @DataType DType@ that occur in a type of any
-- kind.
--
-- Equations of this closed family are tried top to bottom:
--
-- * a type whose kind is @DataType DType@ yields the one-element list
--   containing it;
-- * a type application @f arg@ yields the results for @f@ and @arg@ combined
--   with 'Concat' (imported from "Torch.GraduallyTyped.Prelude"; presumably
--   type-level list concatenation, to be confirmed there);
-- * anything else yields the empty list.
type GetDataTypes :: k -> [DataType DType]
type family GetDataTypes x where
  GetDataTypes (dataType :: DataType DType) = '[dataType]
  GetDataTypes (f arg) = Concat (GetDataTypes f) (GetDataTypes arg)
  GetDataTypes _ = '[]