{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}

module Torch.TensorOptions where

import Data.Int
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Layout

-- | Raw, garbage-collected handle to an ATen @TensorOptions@ object.
type ATenTensorOptions = ForeignPtr ATen.TensorOptions

-- | Haskell-side wrapper around an ATen @TensorOptions@ handle.
--
-- The derived 'Show' instance shows the underlying 'ForeignPtr'; it is meant
-- for debugging, not for inspecting the option values.
newtype TensorOptions = TensorOptions ATenTensorOptions
  deriving (Show)

-- | Marshalling between 'TensorOptions' and its raw ATen handle.
--
-- Both directions only wrap or unwrap the newtype. The foreign object itself
-- is neither copied nor modified.
instance Castable TensorOptions ATenTensorOptions where
  cast (TensorOptions atenOpts) f = f atenOpts
  uncast atenOpts f = f (TensorOptions atenOpts)

-- | Default tensor options, built by 'ATen.newTensorOptions_s' from the
-- 'ATen.kFloat' scalar type.
--
-- This is a top-level value computed with 'unsafePerformIO'. @NOINLINE@ keeps
-- GHC from inlining it, so the foreign allocation happens once and the
-- result is shared by every use site.
{-# NOINLINE defaultOpts #-}
defaultOpts :: TensorOptions
defaultOpts =
  TensorOptions $ unsafePerformIO $ ATen.newTensorOptions_s ATen.kFloat

-- | Options obtained from @opts@ by setting the dtype via
-- 'ATen.tensorOptions_dtype_s'.
--
-- NOTE(review): this is exposed as a pure function through 'unsafePerformIO'.
-- That is only sound if the ATen call returns a fresh object and leaves its
-- input untouched. Confirm against the FFI binding.
withDType :: DType -> TensorOptions -> TensorOptions
withDType dtype opts =
  unsafePerformIO $ cast2 ATen.tensorOptions_dtype_s opts dtype

-- | Retarget tensor options at the given 'Device'.
--
-- Whether CUDA is available is queried once via 'ATen.hasCUDA'. Supported
-- combinations:
--
-- * 'CPU' with index 0 and no CUDA: @opts@ is returned unchanged (presumably
--   because the options already default to CPU; confirm).
-- * 'CPU' with index 0 and CUDA available: the device type is set to 'CPU'
--   explicitly.
-- * 'CUDA' with a non-negative index and CUDA available: the device index is
--   set.
--
-- Every other combination calls 'error' when the result is forced. That
-- covers a non-zero CPU index, a negative CUDA index, CUDA without CUDA
-- support, and any other 'DeviceType'.
withDevice :: Device -> TensorOptions -> TensorOptions
withDevice Device {..} opts = unsafePerformIO $ do
  hasCUDA <- cast0 ATen.hasCUDA
  withDevice' deviceType deviceIndex hasCUDA opts
  where
    -- Set only the device type on the options.
    withDeviceType :: DeviceType -> TensorOptions -> IO TensorOptions
    withDeviceType dt o = cast2 ATen.tensorOptions_device_D o dt

    -- Set only the device index on the options.
    -- NOTE(review): the original author observed that this "somehow implies
    -- deviceType = CUDA", which is why it is only used on the CUDA path.
    withDeviceIndex :: Int16 -> TensorOptions -> IO TensorOptions
    withDeviceIndex di o = cast2 ATen.tensorOptions_device_index_s o di

    -- Dispatch on (device type, index, CUDA available). Clause order matters:
    -- the final catch-all rejects everything the earlier clauses don't accept.
    withDevice' ::
      DeviceType -> Int16 -> Bool -> TensorOptions -> IO TensorOptions
    withDevice' CPU 0 False o = pure o
    withDevice' CPU 0 True o = withDeviceType CPU o
    withDevice' CUDA di True o | di >= 0 = withDeviceIndex di o
    withDevice' dt di _ _ =
      error $ "cannot move tensor to \"" <> show dt <> ":" <> show di <> "\""

-- | Options obtained from @opts@ by setting the memory 'Layout' via
-- 'ATen.tensorOptions_layout_L'.
--
-- NOTE(review): like 'withDType', this relies on 'unsafePerformIO' and is only
-- pure if the ATen call leaves its input object unmodified. Confirm.
withLayout :: Layout -> TensorOptions -> TensorOptions
withLayout layout opts =
  unsafePerformIO $ cast2 ATen.tensorOptions_layout_L opts layout