{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Torch.GraduallyTyped.Scalar where

import Foreign.C.Types (CDouble (..), CInt)
import Foreign.ForeignPtr (ForeignPtr)
import GHC.Float (float2Double)
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Managed.Type.Scalar as ATen
import qualified Torch.Internal.Type as ATen

-- | Converts a 'Float' into a libtorch floating-point scalar via
-- 'ATen.newScalar_d', which takes a 'CDouble'.
--
-- The widening uses 'float2Double' rather than 'realToFrac'. When GHC's
-- rewrite rules do not fire (e.g. at @-O0@ or in GHCi), 'realToFrac' goes
-- through 'Rational' and does not preserve NaN or the infinities.
-- 'float2Double' is exact for every 'Float'.
--
-- 'uncast' (reading a 'Float' back out of a scalar) is not implemented and
-- fails with an explicit error instead of a bare 'undefined'.
instance Castable Float (ForeignPtr ATen.Scalar) where
  cast x f = ATen.newScalar_d (CDouble (float2Double x)) >>= f
  uncast _x _f =
    error
      "Torch.GraduallyTyped.Scalar: uncast from a libtorch Scalar to Float is not implemented"

-- | Converts a 'Double' into a libtorch floating-point scalar via
-- 'ATen.newScalar_d', which takes a 'CDouble'.
--
-- 'CDouble' is a newtype over 'Double', so the value is wrapped with the
-- constructor directly. This is exact for every 'Double', including NaN.
-- 'realToFrac', by contrast, routes through 'Rational' whenever its rewrite
-- rules do not fire (e.g. at @-O0@ or in GHCi) and does not preserve NaN.
--
-- 'uncast' (reading a 'Double' back out of a scalar) is not implemented and
-- fails with an explicit error instead of a bare 'undefined'.
instance Castable Double (ForeignPtr ATen.Scalar) where
  cast x f = ATen.newScalar_d (CDouble x) >>= f
  uncast _x _f =
    error
      "Torch.GraduallyTyped.Scalar: uncast from a libtorch Scalar to Double is not implemented"

-- | Converts an 'Int' into a libtorch integer scalar via 'ATen.newScalar_i',
-- which takes a 'CInt'.
--
-- 'Int' may be wider than 'CInt' (it is on typical 64-bit targets), and
-- 'fromIntegral' silently wraps values that do not fit. Such values are
-- rejected with an 'IOError' rather than producing a scalar that holds a
-- different number.
--
-- 'uncast' (reading an 'Int' back out of a scalar) is not implemented and
-- fails with an explicit error instead of a bare 'undefined'.
instance Castable Int (ForeignPtr ATen.Scalar) where
  cast x f
    | fromIntegral narrowed /= x =
        ioError . userError $
          "Torch.GraduallyTyped.Scalar: Int value "
            ++ show x
            ++ " does not fit in a C int"
    | otherwise = ATen.newScalar_i narrowed >>= f
    where
      -- Candidate C value; it converts back to x only when x is in range.
      narrowed = fromIntegral x :: CInt
  uncast _x _f =
    error
      "Torch.GraduallyTyped.Scalar: uncast from a libtorch Scalar to Int is not implemented"

-- | Converts an 'Integer' into a libtorch integer scalar via
-- 'ATen.newScalar_i', which takes a 'CInt'.
--
-- 'Integer' is unbounded, and 'fromIntegral' silently wraps values outside
-- the 'CInt' range. Such values are rejected with an 'IOError' rather than
-- producing a scalar that holds a different number.
--
-- 'uncast' (reading an 'Integer' back out of a scalar) is not implemented and
-- fails with an explicit error instead of a bare 'undefined'.
instance Castable Integer (ForeignPtr ATen.Scalar) where
  cast x f
    | fromIntegral narrowed /= x =
        ioError . userError $
          "Torch.GraduallyTyped.Scalar: Integer value "
            ++ show x
            ++ " does not fit in a C int"
    | otherwise = ATen.newScalar_i narrowed >>= f
    where
      -- Candidate C value; it converts back to x only when x is in range.
      narrowed = fromIntegral x :: CInt
  uncast _x _f =
    error
      "Torch.GraduallyTyped.Scalar: uncast from a libtorch Scalar to Integer is not implemented"

-- | Haskell number types that may be passed to libtorch as a scalar.
--
-- Membership requires an existing @'Castable' a ('ForeignPtr' 'ATen.Scalar')@
-- instance; the class itself declares no methods and serves purely as a
-- constraint alias for that capability.
class Castable a (ForeignPtr ATen.Scalar) => Scalar a

-- Floating-point members.
instance Scalar Float

instance Scalar Double

-- Integral members.
instance Scalar Int

instance Scalar Integer