{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Torch.Layout where

import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Type as ATen

-- | Tensor layout kinds exposed by this module.
--
-- Values are converted to and from the lower-level 'ATen.Layout'
-- representation through this module's 'Castable' instance.
data Layout
  = Strided
  | Sparse
  | Mkldnn
  deriving (Show, Eq)

-- | Conversion between the high-level 'Layout' and the ATen-level
-- 'ATen.Layout' value from "Torch.Internal.Type".
--
-- Mapping used in both directions:
--
-- * 'Strided' <-> 'ATen.kStrided'
-- * 'Sparse'  <-> 'ATen.kSparse'
-- * 'Mkldnn'  <-> 'ATen.kMkldnn'
--
-- NOTE(review): round-tripping relies on the three ATen constants being
-- pairwise distinct; confirm in "Torch.Internal.Const".
instance Castable Layout ATen.Layout where
  -- Pass the ATen constant for the given 'Layout' to the continuation.
  cast Strided f = f ATen.kStrided
  cast Sparse f = f ATen.kSparse
  cast Mkldnn f = f ATen.kMkldnn

  -- Map an ATen layout value back to a 'Layout' and pass it to the
  -- continuation. Guards are used instead of patterns because the ATen
  -- constants are ordinary values, not constructors.
  uncast x f
    | x == ATen.kStrided = f Strided
    | x == ATen.kSparse = f Sparse
    | x == ATen.kMkldnn = f Mkldnn
    -- Any other value used to fall off the guard chain and fail with an
    -- opaque non-exhaustive-guards error. Raise a descriptive IOError in IO.
    | otherwise =
        ioError . userError $
          "Torch.Layout.uncast: unrecognised ATen layout value "
            ++ "(expected kStrided, kSparse or kMkldnn)"