{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Torch.Device where

import qualified Data.Int as I
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Type as ATen

-- | Device backends recognised by this module.
--
-- Each constructor is mapped to an ATen device-type constant by the
-- 'Castable' instance in this module. The derived 'Ord' instance follows
-- constructor declaration order, so @CPU < CUDA@.
data DeviceType = CPU | CUDA
  deriving (Eq, Ord, Show)

-- | A concrete device: a backend kind together with an index.
--
-- The derived 'Eq' and 'Ord' instances compare 'deviceType' first and then
-- 'deviceIndex'.
data Device = Device
  { -- | Which backend this device belongs to.
    deviceType :: DeviceType,
    -- | Index of the device. NOTE(review): presumably the ordinal among
    -- devices of the same 'DeviceType' (e.g. GPU 0, GPU 1); confirm.
    deviceIndex :: I.Int16
  }
  deriving (Eq, Ord, Show)

-- | Conversion between the Haskell 'DeviceType' and ATen's device-type code.
--
-- Both directions use continuation-passing style, as the 'Castable' class
-- requires: the converted value is handed to the supplied action @f@.
-- NOTE(review): 'ATen.DeviceType' appears to be an 'I.Int16' code; confirm
-- in "Torch.Internal.Type".
instance Castable DeviceType ATen.DeviceType where
  cast CPU f = f ATen.kCPU
  cast CUDA f = f ATen.kCUDA

  -- Guards are used instead of literal patterns because the ATen codes are
  -- imported constants, not constructors. Any code other than the two known
  -- ones is reported as an IO error instead of an opaque guard failure.
  uncast x f
    | x == ATen.kCPU = f CPU
    | x == ATen.kCUDA = f CUDA
    | otherwise =
        ioError . userError $
          "Torch.Device.uncast: unsupported ATen device type " ++ show x