{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Torch.Backend where

import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Type as ATen

-- | Haskell-side enumeration of tensor backends.
--
-- Each constructor corresponds to the @ATen.b*@ constant of the same name;
-- the 'Castable' instance in this module converts between the two
-- representations.
data Backend
  = CPU
  | CUDA
  | HIP
  | SparseCPU
  | SparseCUDA
  | XLA
  deriving (Eq, Show)

-- | Conversion between the Haskell 'Backend' enumeration and the ATen-level
-- backend value, in the continuation-passing style of 'Castable'.
instance Castable Backend ATen.Backend where
  -- Hand the ATen constant matching the given constructor to the
  -- continuation. One equation per constructor, so the match is total.
  cast CPU f = f ATen.bCPU
  cast CUDA f = f ATen.bCUDA
  cast HIP f = f ATen.bHIP
  cast SparseCPU f = f ATen.bSparseCPU
  cast SparseCUDA f = f ATen.bSparseCUDA
  cast XLA f = f ATen.bXLA

  -- Compare the ATen value against each known constant in turn and hand the
  -- corresponding constructor to the continuation.
  uncast x f
    | x == ATen.bCPU = f CPU
    | x == ATen.bCUDA = f CUDA
    | x == ATen.bHIP = f HIP
    | x == ATen.bSparseCPU = f SparseCPU
    | x == ATen.bSparseCUDA = f SparseCUDA
    | x == ATen.bXLA = f XLA
    -- A value matching none of the constants above has no 'Backend'
    -- constructor. Raise a descriptive IOError instead of falling through to
    -- an anonymous pattern-match failure.
    | otherwise =
        ioError (userError "Torch.Backend.uncast: unrecognised ATen backend value")