{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.GraduallyTyped.Internal.TensorOptions where

import Data.Singletons (SingKind (..))
import Foreign.ForeignPtr (ForeignPtr)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (SDataType)
import Torch.GraduallyTyped.Device (DeviceType (..), SDevice)
import Torch.GraduallyTyped.Layout (SLayout)
import Torch.GraduallyTyped.Prelude (forgetIsChecked)
import Torch.GraduallyTyped.RequiresGradient (RequiresGradient (..), SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), SShape)
import Torch.Internal.Cast (cast1, cast2)
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen (kCPU)
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen

-- | Haskell-side handle to a libtorch @TensorOptions@ object, held through a
-- 'ForeignPtr' to the underlying ATen value.
newtype TensorOptions = TensorOptions (ForeignPtr ATen.TensorOptions)

-- | Marshalling between the wrapper and the raw foreign pointer.
--
-- Both directions only add or remove the newtype constructor; the pointed-to
-- ATen object is neither copied nor modified.
instance Castable TensorOptions (ForeignPtr ATen.TensorOptions) where
  -- Unwrap and hand the raw pointer to the continuation.
  cast (TensorOptions ptr) k = k ptr

  -- Wrap the raw pointer and hand the wrapper to the continuation.
  uncast ptr k = k (TensorOptions ptr)

-- | Build a libtorch 'TensorOptions' value from singleton descriptions of the
-- gradient requirement, layout, device and data type.
--
-- Each singleton is demoted with 'fromSing' and stripped of its checked /
-- unchecked tag with 'forgetIsChecked'. The ATen options object is then
-- configured in this order: data type, @requires_grad@, device, layout.
tensorOptions ::
  forall gradient layout device dataType.
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  SDataType dataType ->
  TensorOptions
tensorOptions gradient layout device dataType =
  -- NOTE(review): 'unsafePerformIO' is used on the assumption that these ATen
  -- calls only allocate and configure a fresh options object, with no other
  -- observable effects. That makes the result a function of the arguments
  -- alone. Confirm against the ATen bindings.
  unsafePerformIO $
    cast1 ATen.newTensorOptions_s dType
      >>= setRequiresGrad
      >>= setDevice
      >>= setLayout
  where
    -- Demoted, tag-free values of the four singleton arguments.
    requiresGradient = forgetIsChecked . fromSing $ gradient
    layoutType = forgetIsChecked . fromSing $ layout
    deviceType = forgetIsChecked . fromSing $ device
    dType = forgetIsChecked . fromSing $ dataType

    -- Set @requires_grad@ to true exactly when the gradient is 'WithGradient'.
    setRequiresGrad :: TensorOptions -> IO TensorOptions
    setRequiresGrad opts =
      cast2 ATen.tensorOptions_requires_grad_b opts (requiresGradient == WithGradient)

    -- CPU selects the ATen CPU device type.
    -- CUDA passes its index to ATen's @device_index@ setter.
    setDevice :: TensorOptions -> IO TensorOptions
    setDevice opts = case deviceType of
      CPU -> cast2 ATen.tensorOptions_device_D opts ATen.kCPU
      CUDA deviceIndex -> cast2 ATen.tensorOptions_device_index_s opts deviceIndex

    -- Apply the demoted memory layout.
    setLayout :: TensorOptions -> IO TensorOptions
    setLayout opts = cast2 ATen.tensorOptions_layout_L opts layoutType

-- | Demote a shape singleton to a plain list of dimensions.
--
-- The 'forgetIsChecked' tag is removed at two levels: once from the shape as
-- a whole, and once from each dimension's name and size.
tensorDims ::
  forall shape.
  SShape shape ->
  [Dim String Integer]
tensorDims shapeSing =
  [ Dim (forgetIsChecked name) (forgetIsChecked size)
  | Dim name size <- forgetIsChecked (fromSing shapeSing)
  ]