{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.Device where

import Data.Int (Int16)
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import GHC.TypeLits (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..))
import qualified Torch.Internal.Managed.Cast as ATen ()

-- | Data type to represent compute devices.
--
-- The parameter is the representation of the GPU index. In this module it is
-- used promoted at @'DeviceType' 'Nat'@ for type-level devices and at
-- @'DeviceType' 'Int16'@ for runtime (demoted) devices.
data DeviceType (deviceId :: Type) where
  -- | The tensor is stored in the CPU's memory.
  CPU :: forall deviceId. DeviceType deviceId
  -- | The tensor is stored in the memory of the GPU with ID @deviceId@.
  CUDA :: forall deviceId. deviceId -> DeviceType deviceId
  -- Constructor order matters for the derived 'Ord': 'CPU' sorts before any 'CUDA'.
  deriving stock (Eq, Ord, Show)

-- | Singleton for type-level 'DeviceType's: one constructor per promoted
-- constructor of @'DeviceType' 'Nat'@. The CUDA case carries the GPU index
-- as a 'KnownNat' constraint rather than as a field.
data SDeviceType (dt :: DeviceType Nat) where
  SCPU :: SDeviceType 'CPU
  SCUDA :: forall gpu. KnownNat gpu => SDeviceType ('CUDA gpu)

deriving stock instance Show (SDeviceType dt)

-- Singletons for type-level 'DeviceType's are 'SDeviceType' values.
type instance Sing = SDeviceType

-- | The singleton for the CPU device.
instance SingI ('CPU :: DeviceType Nat) where
  sing = SCPU

-- | The singleton for a CUDA device; the GPU index travels as a 'KnownNat'
-- constraint captured by 'SCUDA'.
instance KnownNat deviceId => SingI ('CUDA deviceId) where
  sing = SCUDA @deviceId

-- | Projects the GPU index out of a promoted @'CUDA@ device. There is no
-- equation for @'CPU@, so the family does not reduce there.
type family CUDAF (dt :: DeviceType Nat) :: Nat where
  CUDAF ('CUDA gpu) = gpu

-- | Conversion between type-level devices (@'DeviceType' 'Nat'@) and their
-- runtime representation (@'DeviceType' 'Int16'@).
instance SingKind (DeviceType Nat) where
  type Demote (DeviceType Nat) = DeviceType Int16

  fromSing SCPU = CPU
  -- Matching 'SCUDA' brings @KnownNat@ of the GPU index into scope. The pattern
  -- signature names the whole index (@'CUDA n@) so 'CUDAF' can project @n@.
  -- NOTE(review): 'fromIntegral' wraps silently for indices above
  -- @maxBound :: Int16@.
  fromSing (SCUDA :: Sing deviceType) =
    CUDA . fromIntegral . natVal $ Proxy @(CUDAF deviceType)

  toSing CPU = SomeSing SCPU
  toSing (CUDA deviceId) =
    case someNatVal (fromIntegral deviceId) of
      Just (SomeNat (_ :: Proxy n)) -> SomeSing (SCUDA @n)
      -- 'someNatVal' rejects negative integers. Fail with a descriptive
      -- message instead of a non-exhaustive pattern-match error.
      Nothing ->
        error $
          "Torch.GraduallyTyped.Device.toSing: CUDA device index must be non-negative, got "
            ++ show deviceId

-- | Type-level 'DeviceType's whose runtime value can be recovered without
-- building a singleton.
class KnownDeviceType (deviceType :: DeviceType Nat) where
  -- | The runtime value of @deviceType@. The type does not mention
  -- @deviceType@, so select the instance with a type application,
  -- e.g. @deviceTypeVal \@('CUDA 0)@.
  deviceTypeVal :: DeviceType Int16

instance KnownDeviceType 'CPU where
  deviceTypeVal = CPU

instance (KnownNat deviceId) => KnownDeviceType ('CUDA deviceId) where
  -- NOTE(review): indices above @maxBound :: Int16@ wrap via 'fromIntegral'.
  deviceTypeVal = CUDA (fromIntegral . natVal $ Proxy @deviceId)

-- | Data type to represent whether or not the compute device is checked, that
-- is, known to the compiler.
--
-- In this module it is used promoted at kind @'Device' ('DeviceType' 'Nat')@.
data Device (deviceType :: Type) where
  -- | The compute device is unknown to the compiler.
  UncheckedDevice :: forall deviceType. Device deviceType
  -- | The compute device is known to the compiler.
  Device :: forall deviceType. deviceType -> Device deviceType
  deriving stock (Show)

-- | Singleton for type-level 'Device's. An unchecked device has no static
-- information, so its singleton carries the runtime 'DeviceType' instead;
-- a checked device wraps the singleton of its 'DeviceType'.
data SDevice (device :: Device (DeviceType Nat)) where
  SUncheckedDevice :: DeviceType Int16 -> SDevice 'UncheckedDevice
  SDevice :: forall dt. SDeviceType dt -> SDevice ('Device dt)

deriving stock instance Show (SDevice device)

-- Singletons for type-level 'Device's are 'SDevice' values.
type instance Sing = SDevice

-- | A checked device's singleton wraps the singleton of its 'DeviceType'.
-- No 'SingI' instance for @'UncheckedDevice@ is defined here, because
-- 'SUncheckedDevice' needs a runtime 'DeviceType' value.
instance SingI deviceType => SingI ('Device (deviceType :: DeviceType Nat)) where
  sing = SDevice (sing @deviceType)

-- | Conversion between type-level devices and their runtime representation.
-- An 'SUncheckedDevice' demotes to 'Unchecked' with its stored runtime device;
-- an 'SDevice' demotes to 'Checked' with its demoted 'DeviceType'.
instance SingKind (Device (DeviceType Nat)) where
  type Demote (Device (DeviceType Nat)) = IsChecked (DeviceType Int16)

  fromSing (SUncheckedDevice deviceType) = Unchecked deviceType
  fromSing (SDevice deviceType) = Checked . fromSing $ deviceType

  toSing (Unchecked deviceType) = SomeSing . SUncheckedDevice $ deviceType
  -- Promote the runtime 'DeviceType' to its singleton first, then wrap it.
  toSing (Checked deviceType) = withSomeSing deviceType $ SomeSing . SDevice

-- | Type-level 'Device's whose runtime value can be recovered without building
-- a singleton.
class KnownDevice (device :: Device (DeviceType Nat)) where
  -- | The runtime value of @device@. The type does not mention @device@, so
  -- select the instance with a type application.
  deviceVal :: Device (DeviceType Int16)

instance KnownDevice 'UncheckedDevice where
  deviceVal = UncheckedDevice

instance (KnownDeviceType deviceType) => KnownDevice ('Device deviceType) where
  deviceVal = Device (deviceTypeVal @deviceType)

-- Collects every type of kind @Device (DeviceType Nat)@ that occurs in the
-- argument, which may itself be of any kind. The equations are tried in order:
--   1. a device yields a one-element list;
--   2. a type application @f g@ combines the devices found in @f@ and in @g@,
--      so structures such as the promoted list and 'Just examples below are
--      searched left to right;
--   3. anything else contributes no devices.
-- Equation order is significant: a device such as @'Device ('CUDA 0)@ is also
-- a type application, and equation 1 must catch it before equation 2 does.
-- NOTE(review): assumes 'Concat' (from Torch.GraduallyTyped.Prelude) is
-- type-level list append; confirm there.
--
-- >>> :kind! GetDevices ('Device ('CUDA 0))
-- GetDevices ('Device ('CUDA 0)) :: [Device (DeviceType Nat)]
-- = '[ 'Device ('CUDA 0)]
-- >>> :kind! GetDevices '[ 'Device 'CPU, 'Device ('CUDA 0)]
-- GetDevices '[ 'Device 'CPU, 'Device ('CUDA 0)] :: [Device
--                                                      (DeviceType Nat)]
-- = '[ 'Device 'CPU, 'Device ('CUDA 0)]
-- >>> :kind! GetDevices ('Just ('Device ('CUDA 0)))
-- GetDevices ('Just ('Device ('CUDA 0))) :: [Device (DeviceType Nat)]
-- = '[ 'Device ('CUDA 0)]
type GetDevices :: k -> [Device (DeviceType Nat)]
type family GetDevices f where
  GetDevices (a :: Device (DeviceType Nat)) = '[a]
  GetDevices (f g) = Concat (GetDevices f) (GetDevices g)
  GetDevices _ = '[]