{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.GraduallyTyped.Internal.TensorOptions where

import Data.Singletons (SingKind (..))
import Foreign.ForeignPtr (ForeignPtr)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (SDataType)
import Torch.GraduallyTyped.Device (DeviceType (..), SDevice)
import Torch.GraduallyTyped.Layout (SLayout)
import Torch.GraduallyTyped.Prelude (forgetIsChecked)
import Torch.GraduallyTyped.RequiresGradient (RequiresGradient (..), SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), SShape)
import Torch.Internal.Cast (cast1, cast2)
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen (kCPU)
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen

-- | Newtype wrapper around a foreign pointer to ATen's @TensorOptions@.
newtype TensorOptions = TensorOptions (ForeignPtr ATen.TensorOptions)

instance Castable TensorOptions (ForeignPtr ATen.TensorOptions) where
  cast (TensorOptions opts) f = f opts
  uncast opts f = f $ TensorOptions opts

-- | Build ATen 'TensorOptions' from the singleton representations of the
-- gradient requirement, memory layout, device, and data type.
tensorOptions ::
  forall gradient layout device dataType.
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  SDataType dataType ->
  TensorOptions
tensorOptions gradient layout device dataType = unsafePerformIO $ do
  -- start from options carrying only the scalar type, then set the
  -- remaining fields one by one
  opts :: TensorOptions <- cast1 ATen.newTensorOptions_s dType
  opts :: TensorOptions <-
    let b = requiresGradient == WithGradient
     in cast2 ATen.tensorOptions_requires_grad_b opts b
  opts :: TensorOptions <- withDevice deviceType opts
  opts :: TensorOptions <- cast2 ATen.tensorOptions_layout_L opts layoutType
  return opts
  where
    requiresGradient = forgetIsChecked . fromSing $ gradient
    layoutType = forgetIsChecked . fromSing $ layout
    deviceType = forgetIsChecked . fromSing $ device
    dType = forgetIsChecked . fromSing $ dataType
    withDevice CPU opts = cast2 ATen.tensorOptions_device_D opts ATen.kCPU
    withDevice (CUDA deviceIndex) opts = cast2 ATen.tensorOptions_device_index_s opts deviceIndex

-- | Recover the list of dimensions from a shape singleton, forgetting
-- whether the dimension names and sizes were statically checked.
tensorDims ::
  forall shape.
  SShape shape ->
  [Dim String Integer]
tensorDims =
  fmap (\(Dim name size) -> Dim (forgetIsChecked name) (forgetIsChecked size))
    . forgetIsChecked
    . fromSing
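-- A minimal usage sketch, not part of the original module: constructing
-- options for a dense Float tensor on the CPU without gradient tracking.
-- The singleton constructors used here ('SGradient', 'SWithoutGradient',
-- 'SLayout', 'SDense', 'SDevice', 'SCPU', 'SDataType', 'SFloat') are assumed
-- to come from the surrounding Torch.GraduallyTyped.* modules and are not
-- imported by this internal module, so the example is left as a comment:
--
-- > cpuFloatOptions :: TensorOptions
-- > cpuFloatOptions =
-- >   tensorOptions
-- >     (SGradient SWithoutGradient)
-- >     (SLayout SDense)
-- >     (SDevice SCPU)
-- >     (SDataType SFloat)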