{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Torch.Scalar where

import Foreign.C.Types (CDouble (..), CFloat (..))
import Foreign.ForeignPtr

import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import Torch.Internal.Managed.Cast
import qualified Torch.Internal.Managed.Type.Scalar as ATen
import qualified Torch.Internal.Type as ATen

-- | Marshal a 'Float' into an ATen scalar built by 'ATen.newScalar_f' and
-- pass the resulting pointer to the continuation.
--
-- The value is wrapped with the 'CFloat' constructor rather than converted
-- with 'realToFrac'. When GHC's rewrite rules do not fire (e.g. at @-O0@),
-- 'realToFrac' goes through 'Rational', which does not preserve NaN or
-- negative zero.
--
-- Reading a 'Float' back out of an ATen scalar ('uncast') is not
-- implemented and raises an error.
instance Castable Float (ForeignPtr ATen.Scalar) where
  cast x f = ATen.newScalar_f (CFloat x) >>= f
  uncast _ _ =
    error "Torch.Scalar: uncast from ForeignPtr ATen.Scalar to Float is not implemented"

-- | Marshal a 'Double' into an ATen scalar built by 'ATen.newScalar_d' and
-- pass the resulting pointer to the continuation.
--
-- The value is wrapped with the 'CDouble' constructor rather than converted
-- with 'realToFrac'. When GHC's rewrite rules do not fire (e.g. at @-O0@),
-- 'realToFrac' goes through 'Rational', which does not preserve NaN or
-- negative zero.
--
-- Reading a 'Double' back out of an ATen scalar ('uncast') is not
-- implemented and raises an error.
instance Castable Double (ForeignPtr ATen.Scalar) where
  cast x f = ATen.newScalar_d (CDouble x) >>= f
  uncast _ _ =
    error "Torch.Scalar: uncast from ForeignPtr ATen.Scalar to Double is not implemented"

-- | Marshal an 'Int' into an ATen scalar built by 'ATen.newScalar_i' and
-- pass the resulting pointer to the continuation.
--
-- 'ATen.newScalar_i' takes a C @int@, which may be narrower than Haskell's
-- 'Int' (platform-dependent). A plain 'fromIntegral' would silently wrap
-- out-of-range values. Instead, the narrowed value is converted back and
-- compared, and an 'IOError' is raised if it does not round-trip.
--
-- Reading an 'Int' back out of an ATen scalar ('uncast') is not
-- implemented and raises an error.
instance Castable Int (ForeignPtr ATen.Scalar) where
  cast x f = withNarrowed (fromIntegral x)
    where
      -- The argument is lambda-bound, so its type is fixed by the call to
      -- 'ATen.newScalar_i' regardless of the monomorphism restriction.
      withNarrowed narrowed
        | fromIntegral narrowed /= x =
            ioError . userError $
              "Torch.Scalar: Int "
                ++ show x
                ++ " does not fit in the C int expected by ATen.newScalar_i"
        | otherwise = ATen.newScalar_i narrowed >>= f
  uncast _ _ =
    error "Torch.Scalar: uncast from ForeignPtr ATen.Scalar to Int is not implemented"

-- | Marshal a 'Bool' into an ATen scalar built by 'ATen.newScalar_b' and
-- pass the resulting pointer to the continuation. 'True' is encoded as 1 and
-- 'False' as 0.
--
-- Reading a 'Bool' back out of an ATen scalar ('uncast') is not
-- implemented and raises an error.
instance Castable Bool (ForeignPtr ATen.Scalar) where
  cast x f = ATen.newScalar_b (if x then 1 else 0) >>= f
  uncast _ _ =
    error "Torch.Scalar: uncast from ForeignPtr ATen.Scalar to Bool is not implemented"

-- | Haskell types that can be converted into an ATen scalar.
--
-- 'Scalar' declares no methods. It is a named shorthand for the superclass
-- constraint @Castable a (ForeignPtr ATen.Scalar)@, so every member type can
-- be 'cast' to a 'ForeignPtr' to an ATen scalar.
class Castable a (ForeignPtr ATen.Scalar) => Scalar a

-- Each member below relies on the matching 'Castable' instance in this module.
instance Scalar Float

instance Scalar Double

instance Scalar Int

instance Scalar Bool