{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
module Torch.TensorOptions where
import Data.Int
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Layout
-- | Raw foreign handle to a libtorch @TensorOptions@ object, as passed to
-- and returned from the @ATen.tensorOptions_*@ bindings.
type ATenTensorOptions = ForeignPtr ATen.TensorOptions

-- | Haskell-side wrapper around an 'ATenTensorOptions' handle.
--
-- The derived 'Show' instance shows the wrapped 'ForeignPtr' (i.e. a
-- pointer), not the option values themselves.
newtype TensorOptions = TensorOptions ATenTensorOptions deriving (Show)
-- | Marshal between the 'TensorOptions' wrapper and the raw
-- 'ATenTensorOptions' handle used by the FFI bindings.
--
-- Both directions only wrap or unwrap the newtype before handing the value
-- to the continuation; no libtorch call is made here.
instance Castable TensorOptions ATenTensorOptions where
  cast (TensorOptions atenOpts) k = k atenOpts
  uncast atenOpts k = k (TensorOptions atenOpts)
-- | Default tensor options, created by @ATen.newTensorOptions_s@ with
-- 'ATen.kFloat' as the scalar type.
--
-- NOTE(review): the remaining settings (device, layout, ...) are presumably
-- libtorch's own defaults; confirm against the C++ @TensorOptions@ ctor.
--
-- 'unsafePerformIO' is used so this can be a plain top-level value; this is
-- only sound as long as creating the options object has no observable
-- effect beyond allocation (assumed, not checked here). @NOINLINE@ makes
-- GHC evaluate the CAF once and share it, rather than duplicating the
-- allocation at every inlined use site.
defaultOpts :: TensorOptions
defaultOpts =
  TensorOptions $ unsafePerformIO $ ATen.newTensorOptions_s ATen.kFloat
{-# NOINLINE defaultOpts #-}
-- | Options identical to the argument except for the element type, which is
-- set to the given 'DType' via @ATen.tensorOptions_dtype_s@.
--
-- The libtorch call returns a new handle, which is wrapped and returned.
-- 'unsafePerformIO' presents this as a pure function.
-- NOTE(review): purity relies on the libtorch setter not mutating the input
-- options object; confirm.
withDType :: DType -> TensorOptions -> TensorOptions
withDType dtype opts =
  unsafePerformIO $ cast2 ATen.tensorOptions_dtype_s opts dtype
-- | Options targeting the given 'Device'.
--
-- First asks libtorch whether CUDA is available (@ATen.hasCUDA@), then:
--
--   * @CPU@, index 0, no CUDA: the options are returned unchanged.
--   * @CPU@, index 0, CUDA available: the device type is explicitly set to
--     @CPU@ via @ATen.tensorOptions_device_D@.
--   * @CUDA@, non-negative index, CUDA available: the device index is set
--     via @ATen.tensorOptions_device_index_s@.
--     NOTE(review): this relies on libtorch's @device_index@ also selecting
--     the CUDA device type; confirm.
--   * Anything else (CUDA unavailable, negative index, non-zero CPU index,
--     other device types): calls 'error' when the result is forced.
--
-- 'unsafePerformIO' presents the FFI calls as a pure function.
withDevice :: Device -> TensorOptions -> TensorOptions
withDevice Device {..} opts = unsafePerformIO $ do
  hasCUDA <- cast0 ATen.hasCUDA
  withDevice' deviceType deviceIndex hasCUDA
  where
    -- Set only the device type of 'opts'.
    withDeviceType :: DeviceType -> IO TensorOptions
    withDeviceType dt = cast2 ATen.tensorOptions_device_D opts dt

    -- Set only the device index of 'opts'.
    withDeviceIndex :: Int16 -> IO TensorOptions
    withDeviceIndex di = cast2 ATen.tensorOptions_device_index_s opts di

    -- Dispatch on (device type, device index, CUDA available?).
    withDevice' :: DeviceType -> Int16 -> Bool -> IO TensorOptions
    withDevice' CPU 0 False = pure opts
    withDevice' CPU 0 True = withDeviceType CPU
    withDevice' CUDA di True | di >= 0 = withDeviceIndex di
    withDevice' dt di _ =
      error $ "cannot move tensor to \"" <> show dt <> ":" <> show di <> "\""
-- | Options identical to the argument except for the memory layout, which is
-- set to the given 'Layout' via @ATen.tensorOptions_layout_L@.
--
-- The libtorch call returns a new handle, which is wrapped and returned.
-- 'unsafePerformIO' presents this as a pure function.
-- NOTE(review): purity relies on the libtorch setter not mutating the input
-- options object; confirm.
withLayout :: Layout -> TensorOptions -> TensorOptions
withLayout layout opts =
  unsafePerformIO $ cast2 ATen.tensorOptions_layout_L opts layout