{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.Device where
import Data.Int (Int16)
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import GHC.TypeLits (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..))
import qualified Torch.Internal.Managed.Cast as ATen ()
-- | The kind of device: the CPU, or a CUDA device identified by a value of
-- type @deviceId@.
--
-- Within this module @deviceId@ is instantiated to 'Nat' at the type level
-- and to 'Int16' at the value level.
data DeviceType (deviceId :: Type) where
  -- | The CPU.
  CPU :: forall deviceId. DeviceType deviceId
  -- | The CUDA device with the given identifier.
  CUDA :: forall deviceId. deviceId -> DeviceType deviceId
  deriving stock (Eq, Ord, Show)
-- | Singleton for type-level @'DeviceType' 'Nat'@ values: one constructor per
-- promoted 'DeviceType' constructor. 'SCUDA' packs the 'KnownNat' evidence
-- for the device index, so matching on it brings that evidence into scope.
data SDeviceType (deviceType :: DeviceType Nat) where
  SCPU :: SDeviceType 'CPU
  SCUDA :: forall idx. KnownNat idx => SDeviceType ('CUDA idx)

deriving stock instance Show (SDeviceType (deviceType :: DeviceType Nat))

-- Register 'SDeviceType' as the singleton type for the kind @DeviceType Nat@.
type instance Sing = SDeviceType
-- | A promoted 'CPU' has exactly one singleton, 'SCPU'.
instance SingI ('CPU :: DeviceType Nat) where
  sing = SCPU
-- | A promoted @'CUDA' deviceId@ is witnessed by 'SCUDA'. The instance
-- context supplies the 'KnownNat' evidence that 'SCUDA' stores, and the
-- index is fixed by unification with the expected singleton type.
instance KnownNat deviceId => SingI ('CUDA deviceId) where
  sing = SCUDA
-- | Project the device index out of a promoted @'CUDA' idx@. There is no
-- equation for 'CPU', so @CUDAF 'CPU@ does not reduce.
type family CUDAF (deviceType :: DeviceType Nat) :: Nat where
  CUDAF ('CUDA idx) = idx
-- | Conversion between type-level @'DeviceType' 'Nat'@ singletons and
-- runtime @'DeviceType' 'Int16'@ values.
--
-- * 'fromSing' reads the 'KnownNat' index stored in 'SCUDA'. It fails
--   loudly, rather than silently wrapping, when the index exceeds
--   @maxBound :: Int16@.
-- * 'toSing' promotes a runtime CUDA id to a 'Nat'. A negative id has no
--   'Nat' counterpart and is reported with an explicit error instead of a
--   non-exhaustive pattern-match failure.
instance SingKind (DeviceType Nat) where
  type Demote (DeviceType Nat) = DeviceType Int16
  fromSing SCPU = CPU
  -- The pattern signature binds @deviceId@ to the full @'CUDA n@ index;
  -- 'CUDAF' recovers @n@, whose 'KnownNat' evidence comes from 'SCUDA'.
  fromSing (SCUDA :: Sing deviceId) =
    CUDA . toInt16 . natVal $ Proxy @(CUDAF deviceId)
    where
      -- 'fromIntegral' into 'Int16' would wrap modulo 2^16 and silently
      -- select a different device.
      toInt16 :: Integer -> Int16
      toInt16 n
        | n > toInteger (maxBound :: Int16) =
          error $
            "Torch.GraduallyTyped.Device.fromSing: CUDA device id "
              <> show n
              <> " does not fit in Int16"
        | otherwise = fromIntegral n
  toSing CPU = SomeSing SCPU
  toSing (CUDA deviceId) =
    case someNatVal (fromIntegral deviceId) of
      Just (SomeNat (_ :: Proxy deviceId)) -> SomeSing (SCUDA @deviceId)
      -- 'someNatVal' returns 'Nothing' for negative input.
      Nothing ->
        error $
          "Torch.GraduallyTyped.Device.toSing: CUDA device id must be non-negative, got "
            <> show deviceId
-- | Reflect a type-level @'DeviceType' 'Nat'@ to its runtime
-- @'DeviceType' 'Int16'@ counterpart.
--
-- The class variable does not occur in the method type, so callers pick
-- the instance with a type application, e.g. @deviceTypeVal \@'CPU@.
class KnownDeviceType (deviceType :: DeviceType Nat) where
  deviceTypeVal :: DeviceType Int16
-- | A promoted 'CPU' reflects to the runtime 'CPU'.
instance KnownDeviceType 'CPU where
  deviceTypeVal = CPU
-- | A promoted @'CUDA' deviceId@ reflects to a runtime 'CUDA' carrying the
-- index as an 'Int16'.
--
-- An index above @maxBound :: Int16@ is reported with an explicit error.
-- Without the check, 'fromIntegral' would wrap it and silently select a
-- different device.
instance (KnownNat deviceId) => KnownDeviceType ('CUDA deviceId) where
  deviceTypeVal = CUDA (toInt16 (natVal (Proxy @deviceId)))
    where
      toInt16 :: Integer -> Int16
      toInt16 n
        | n > toInteger (maxBound :: Int16) =
          error $
            "Torch.GraduallyTyped.Device.deviceTypeVal: CUDA device id "
              <> show n
              <> " does not fit in Int16"
        | otherwise = fromIntegral n
-- | A device annotation that is either absent ('UncheckedDevice') or given
-- explicitly as a @deviceType@ ('Device').
--
-- The 'SingKind' instance for @Device (DeviceType Nat)@ demotes
-- 'UncheckedDevice' to 'Unchecked' and 'Device' to 'Checked'.
data Device (deviceType :: Type) where
  -- | No device information is recorded.
  UncheckedDevice :: forall deviceType. Device deviceType
  -- | The given device.
  Device :: forall deviceType. deviceType -> Device deviceType
  deriving stock (Show)
-- | Singleton for type-level @'Device' ('DeviceType' 'Nat')@ values.
--
-- 'SUncheckedDevice' carries the device only as a runtime
-- @'DeviceType' 'Int16'@. 'SDevice' wraps the singleton of a type-level
-- 'DeviceType'.
data SDevice (deviceType :: Device (DeviceType Nat)) where
  SUncheckedDevice :: DeviceType Int16 -> SDevice 'UncheckedDevice
  SDevice :: forall devType. SDeviceType devType -> SDevice ('Device devType)

deriving stock instance Show (SDevice (device :: Device (DeviceType Nat)))

-- Register 'SDevice' as the singleton type for the kind
-- @Device (DeviceType Nat)@.
type instance Sing = SDevice
-- | The singleton of a promoted @'Device' deviceType@ wraps the singleton of
-- the inner 'DeviceType'.
instance SingI deviceType => SingI ('Device (deviceType :: DeviceType Nat)) where
  sing = SDevice (sing @deviceType)
-- | Conversion between type-level 'Device' singletons and runtime
-- @'IsChecked' ('DeviceType' 'Int16')@ values:
--
-- * 'SUncheckedDevice' <-> 'Unchecked', passing the runtime device through
--   unchanged;
-- * 'SDevice' <-> 'Checked', going through the 'DeviceType' singleton's own
--   'SingKind' instance.
instance SingKind (Device (DeviceType Nat)) where
  type Demote (Device (DeviceType Nat)) = IsChecked (DeviceType Int16)
  fromSing (SUncheckedDevice runtimeType) = Unchecked runtimeType
  fromSing (SDevice sDeviceType) = Checked (fromSing sDeviceType)
  toSing (Unchecked runtimeType) = SomeSing (SUncheckedDevice runtimeType)
  toSing (Checked runtimeType) =
    withSomeSing runtimeType (\sDeviceType -> SomeSing (SDevice sDeviceType))
-- | Reflect a type-level @'Device' ('DeviceType' 'Nat')@ to a runtime
-- @'Device' ('DeviceType' 'Int16')@.
--
-- The class variable does not occur in the method type, so callers pick
-- the instance with a type application, e.g. @deviceVal \@'UncheckedDevice@.
class KnownDevice (device :: Device (DeviceType Nat)) where
  deviceVal :: Device (DeviceType Int16)
-- | A promoted 'UncheckedDevice' reflects to the runtime 'UncheckedDevice'.
instance KnownDevice 'UncheckedDevice where
  deviceVal = UncheckedDevice
-- | A promoted @'Device' deviceType@ reflects to a runtime 'Device' wrapping
-- the reflected 'DeviceType'.
instance KnownDeviceType deviceType => KnownDevice ('Device deviceType) where
  deviceVal = Device $ deviceTypeVal @deviceType
-- | Collect the @Device (DeviceType Nat)@ types that occur in a type of any
-- kind.
--
-- A device at the root yields a singleton list. A type application is split
-- into function and argument, and their results are combined with
-- 'Concat'. Any other type contributes nothing. Equations are tried in
-- order (closed type family).
type GetDevices :: k -> [Device (DeviceType Nat)]
type family GetDevices f where
  GetDevices (device :: Device (DeviceType Nat)) = '[device]
  GetDevices (fun arg) = Concat (GetDevices fun) (GetDevices arg)
  GetDevices _ = '[]