{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- | Conversions from plain Haskell values ('Float', 'Double', 'Int',
-- 'Bool') into libtorch @Scalar@ handles, plus the 'Scalar' class that
-- collects every type which can be cast this way.
--
-- Only the Haskell-to-@Scalar@ direction ('cast') is implemented; every
-- 'uncast' here raises an error when run.
module Torch.Scalar where

import Data.Bits (toIntegralSized)
import Foreign.C.Types (CDouble (..), CFloat (..))
import Foreign.ForeignPtr
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import Torch.Internal.Managed.Cast
import qualified Torch.Internal.Managed.Type.Scalar as ATen
import qualified Torch.Internal.Type as ATen

-- | A 'Float' becomes a @Scalar@ built by @newScalar_f@.
instance Castable Float (ForeignPtr ATen.Scalar) where
  -- Wrap with the 'CFloat' newtype constructor instead of 'realToFrac':
  -- the generic 'realToFrac' goes through 'Rational', which cannot
  -- represent NaN or negative zero, so unless GHC's rewrite rules happen
  -- to fire (they do not at -O0 or in GHCi) a NaN argument would reach
  -- libtorch as -Infinity and -0.0 as +0.0.
  cast x f = ATen.newScalar_f (CFloat x) >>= f
  uncast _ _ =
    error "Torch.Scalar: uncast from ForeignPtr Scalar to Float is not implemented"

-- | A 'Double' becomes a @Scalar@ built by @newScalar_d@.
instance Castable Double (ForeignPtr ATen.Scalar) where
  -- Exact wrap via the 'CDouble' constructor; see the 'Float' instance
  -- for why 'realToFrac' is avoided.
  cast x f = ATen.newScalar_d (CDouble x) >>= f
  uncast _ _ =
    error "Torch.Scalar: uncast from ForeignPtr Scalar to Double is not implemented"

-- | An 'Int' becomes a @Scalar@ built by @newScalar_i@.
--
-- Throws an 'IOError' (via 'ioError') when the value does not fit in a
-- 'Foreign.C.Types.CInt'.
instance Castable Int (ForeignPtr ATen.Scalar) where
  -- @newScalar_i@ takes a 'Foreign.C.Types.CInt', which is narrower than
  -- 'Int' on 64-bit platforms. A plain 'fromIntegral' would silently wrap
  -- large values into a different number, so reject them explicitly.
  cast x f = case toIntegralSized x of
    Just narrowed -> ATen.newScalar_i narrowed >>= f
    Nothing ->
      ioError . userError $
        "Torch.Scalar: Int value "
          ++ show x
          ++ " is out of range for newScalar_i (CInt)"
  uncast _ _ =
    error "Torch.Scalar: uncast from ForeignPtr Scalar to Int is not implemented"

-- | A 'Bool' becomes a @Scalar@ built by @newScalar_b@, passing 1 for
-- 'True' and 0 for 'False'.
instance Castable Bool (ForeignPtr ATen.Scalar) where
  cast x f = ATen.newScalar_b (if x then 1 else 0) >>= f
  uncast _ _ =
    error "Torch.Scalar: uncast from ForeignPtr Scalar to Bool is not implemented"

-- | Haskell types that can be turned into a libtorch @Scalar@ through
-- 'Castable'.
class (Castable a (ForeignPtr ATen.Scalar)) => Scalar a

instance Scalar Float

instance Scalar Double

instance Scalar Int

instance Scalar Bool