{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Torch.GraduallyTyped.Random where

import Control.Concurrent.STM (TVar, atomically, newTVarIO, readTVar, writeTVar)
import Control.Concurrent.STM.TVar (readTVarIO)
import Control.Monad.Catch (MonadThrow)
import Data.Int (Int16)
import Data.Proxy (Proxy (Proxy))
import Data.Singletons (SingI (..), SingKind (..))
import Data.Word (Word64)
import Foreign.ForeignPtr (ForeignPtr)
import GHC.TypeLits (KnownNat, Nat, natVal)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..), SDeviceType (..))
import Torch.GraduallyTyped.Internal.TensorOptions (TensorOptions (..), tensorDims, tensorOptions)
import Torch.GraduallyTyped.Prelude (forgetIsChecked, pattern Demoted')
import Torch.GraduallyTyped.Shape.Type (Dim)
import Torch.GraduallyTyped.Tensor.Type (Tensor (UnsafeTensor), TensorSpec (..), gitHubErrorMsg)
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.Internal.Cast (cast0, cast1, cast2, cast4)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.Generator as ATen
import qualified Torch.Internal.Type as ATen

-- | A random number generator tagged at the type level with its device.
--
-- The 'TVar' holds one of two states:
--
-- * 'Right': a live ATen generator pointer (this is what 'sMkGenerator'
--   creates);
-- * 'Left': a device singleton and a seed from which a new ATen generator
--   can be created (this is what 'getGenPtr' leaves behind after handing
--   out the live pointer).
newtype Generator (device :: Device (DeviceType Nat))
  = UnsafeGenerator (TVar (Either (SDevice device, Word64) (ForeignPtr ATen.Generator)))

type role Generator nominal

-- | Create a generator on the device described by the given singleton,
-- seeded with the given value.
--
-- On a CPU device the ATen generator is constructed directly from the seed.
-- On a CUDA device a generator for that CUDA index is constructed first and
-- its current seed is then set explicitly.
sMkGenerator ::
  forall m device.
  MonadThrow m =>
  -- | generator device singleton
  SDevice device ->
  -- | initial seed
  Word64 ->
  -- | returned generator
  m (Generator device)
sMkGenerator device seed =
  unsafeThrowableIO $ do
    genPtr <- allocate . forgetIsChecked . fromSing $ device
    UnsafeGenerator <$> newTVarIO (Right genPtr)
  where
    -- Allocate a raw ATen generator for a demoted device type.
    allocate :: DeviceType Int16 -> IO (ForeignPtr ATen.Generator)
    allocate CPU = ATen.newCPUGenerator seed
    allocate (CUDA deviceId) = do
      genPtr <- ATen.newCUDAGenerator (fromIntegral deviceId)
      ATen.generator_set_current_seed genPtr seed
      pure genPtr

-- | Like 'sMkGenerator', but the device is obtained from the type-level
-- @device@ through its 'SingI' instance.
mkGenerator ::
  forall m device.
  (SingI device, MonadThrow m) =>
  -- | initial seed
  Word64 ->
  -- | returned generator
  m (Generator device)
mkGenerator seed = sMkGenerator (sing @device) seed

-- | Move a generator to the device described by the given singleton.
--
-- Previously this was 'undefined', so every call crashed at runtime.
--
-- The result is created lazily. Its 'TVar' starts in the @Left (device, seed)@
-- state, so no ATen generator is allocated here. One is allocated on the
-- target device (via 'getGenPtr') the first time the returned generator is used.
--
-- Only the seed is carried over, not the full RNG state. If the input holds a
-- live ATen generator, its seed is read with @generator_current_seed@ and the
-- new generator starts from that seed. The input generator is left untouched.
sSetGeneratorDevice ::
  forall m generatorDevice' generatorDevice.
  MonadThrow m =>
  SDevice generatorDevice' ->
  Generator generatorDevice ->
  m (Generator generatorDevice')
sSetGeneratorDevice generatorDevice' (UnsafeGenerator tvar) =
  unsafeThrowableIO $ do
    state <- readTVarIO tvar
    seed <- case state of
      -- Not materialized yet: reuse the stored seed.
      Left (_, storedSeed) -> pure storedSeed
      -- Materialized: ask ATen for the generator's current seed.
      Right genPtr -> cast1 ATen.generator_current_seed genPtr
    UnsafeGenerator <$> newTVarIO (Left (generatorDevice', seed))

-- | Like 'sSetGeneratorDevice', but the target device is obtained from the
-- type-level @generatorDevice'@ through its 'SingI' instance.
setGeneratorDevice ::
  forall m generatorDevice' generatorDevice.
  (SingI generatorDevice', MonadThrow m) =>
  Generator generatorDevice ->
  m (Generator generatorDevice')
setGeneratorDevice generator =
  sSetGeneratorDevice (sing @generatorDevice') generator

-- | Devices for which the device singleton of a raw ATen generator pointer
-- can be recovered.
class SGetGeneratorDevice (device :: Device (DeviceType Nat)) where
  -- | Inspect a raw ATen generator and return its device as a singleton of
  -- type @device@.
  sGetGenPtrDevice :: ForeignPtr ATen.Generator -> SDevice device

-- | For an unchecked device, the device is read from the generator itself.
-- It is reported as CUDA (with the generator's device index) when CUDA is
-- available and the generator is a CUDA generator, and as CPU otherwise.
instance SGetGeneratorDevice 'UncheckedDevice where
  sGetGenPtrDevice genPtr
    | isCudaGenerator = SUncheckedDevice (CUDA (fromIntegral cudaIndex))
    | otherwise = SUncheckedDevice CPU
    where
      isCudaGenerator =
        unsafePerformIO (cast0 ATen.hasCUDA)
          && unsafePerformIO (cast1 ATen.generator_is_cuda genPtr)
      -- Only queried when the generator was found to be on CUDA.
      cudaIndex :: Int
      cudaIndex = unsafePerformIO (cast1 ATen.generator_get_device genPtr)

-- | For a statically CPU-typed generator, calls 'error' if the underlying
-- ATen generator turns out to be a CUDA generator.
instance SGetGeneratorDevice ('Device 'CPU) where
  sGetGenPtrDevice genPtr
    | not onCuda = SDevice SCPU
    | otherwise =
        error $ "The generator should be on CPU but is on CUDA. " <> gitHubErrorMsg
    where
      onCuda =
        unsafePerformIO (cast0 ATen.hasCUDA)
          && unsafePerformIO (cast1 ATen.generator_is_cuda genPtr)

-- | For a statically CUDA-typed generator, checks that the underlying ATen
-- generator is a CUDA generator with the expected device index.
-- Calls 'error' otherwise.
instance KnownNat deviceIndex => SGetGeneratorDevice ('Device ('CUDA deviceIndex)) where
  sGetGenPtrDevice genPtr
    | not onCuda =
        error $ "The generator should be on CUDA but is on CPU. " <> gitHubErrorMsg
    | actualIndex == fromIntegral expectedIndex = SDevice SCUDA
    | otherwise =
        error $
          "The generator should be on CUDA device "
            <> show expectedIndex
            <> " but is on device "
            <> show actualIndex
            <> ". "
            <> gitHubErrorMsg
    where
      onCuda =
        unsafePerformIO (cast0 ATen.hasCUDA)
          && unsafePerformIO (cast1 ATen.generator_is_cuda genPtr)
      -- Device index fixed by the type.
      expectedIndex = natVal (Proxy @deviceIndex)
      -- Device index reported by ATen; only queried once onCuda holds.
      actualIndex :: Int
      actualIndex = unsafePerformIO (cast1 ATen.generator_get_device genPtr)

-- | Device singleton of a generator.
--
-- If the generator is not materialized yet, the stored device singleton is
-- returned. Otherwise the device is recovered from the live ATen pointer via
-- 'sGetGenPtrDevice'.
sGetGeneratorDevice ::
  forall device.
  SGetGeneratorDevice device =>
  -- | input
  Generator device ->
  -- | device singleton of the input generator
  SDevice device
sGetGeneratorDevice (UnsafeGenerator tvar) =
  unsafePerformIO $ either fst sGetGenPtrDevice <$> readTVarIO tvar

-- | Demoted (value-level) device type of a generator, e.g. @CPU@ or
-- @CUDA 0@.
getGeneratorDeviceType ::
  forall device.
  SGetGeneratorDevice device =>
  -- | input
  Generator device ->
  -- | device type of the input generator
  DeviceType Int16
getGeneratorDeviceType = forgetIsChecked . fromSing . sGetGeneratorDevice

-- | Obtain an ATen generator pointer from a generator's 'TVar'.
--
-- * If the 'TVar' holds a live pointer ('Right'), that pointer is returned.
--   In the same transaction the 'TVar' is overwritten with the pointer's
--   device and current seed ('Left'), so a later use of the same 'Generator'
--   value creates a new ATen generator from that seed.
-- * If the 'TVar' holds a device and seed ('Left'), a new ATen generator is
--   allocated on that device with that seed. The 'TVar' is not modified.
getGenPtr ::
  forall device.
  SGetGeneratorDevice device =>
  TVar (Either (SDevice device, Word64) (ForeignPtr ATen.Generator)) ->
  IO (ForeignPtr ATen.Generator)
getGenPtr tvar = atomically takeOrInvalidate >>= either allocate pure
  where
    -- Read the state. A live pointer is replaced by its (device, seed)
    -- description before the old state is returned.
    takeOrInvalidate = do
      current <- readTVar tvar
      case current of
        Left _ -> pure current
        Right genPtr -> do
          -- Forced before the write, as in the original definition.
          let !dev = sGetGenPtrDevice genPtr :: SDevice device
              !seed = unsafePerformIO (cast1 ATen.generator_current_seed genPtr) :: Word64
          writeTVar tvar (Left (dev, seed))
          pure current

    -- Create a new ATen generator for a stored (device, seed) pair.
    allocate :: (SDevice device, Word64) -> IO (ForeignPtr ATen.Generator)
    allocate (dev, seed) = case forgetIsChecked (fromSing dev) of
      CPU -> ATen.newCPUGenerator seed
      CUDA deviceId -> do
        genPtr <- ATen.newCUDAGenerator (fromIntegral deviceId)
        ATen.generator_set_current_seed genPtr seed
        pure genPtr

-- | Create a tensor with a raw ATen creation function that consumes a
-- generator.
--
-- The generator pointer is taken with 'getGenPtr', which marks the input
-- generator as consumed. @rawCreateFn@ is called with the tensor options and
-- dimensions derived from the 'TensorSpec' and with that pointer. The
-- resulting tensor is returned together with a new 'Generator' that wraps
-- the same pointer.
sCreateWithGenerator ::
  forall m gradient layout device dataType shape generatorDevice.
  (SGetGeneratorDevice generatorDevice, MonadThrow m) =>
  TensorSpec gradient layout device dataType shape ->
  Generator generatorDevice ->
  (ForeignPtr ATen.TensorOptions -> [Dim String Integer] -> ForeignPtr ATen.Generator -> IO (ForeignPtr ATen.Tensor)) ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sCreateWithGenerator TensorSpec {..} (UnsafeGenerator tvar) rawCreateFn =
  unsafeThrowableIO $ do
    genPtr <- getGenPtr tvar
    tPtr <- rawCreateFn opts dims genPtr
    nextTVar <- newTVarIO (Right genPtr)
    pure (UnsafeTensor tPtr, UnsafeGenerator nextTVar)
  where
    TensorOptions opts = tensorOptions tsGradient tsLayout tsDevice tsDataType
    dims = tensorDims tsShape

-- | Apply a raw ATen operation that takes a tensor and a generator.
--
-- The generator pointer is taken with 'getGenPtr', which marks the input
-- generator as consumed. @rawForwardFn@ is run on the input tensor and that
-- pointer. The output tensor is returned together with a new 'Generator'
-- that wraps the same pointer.
sForwardWithGenerator ::
  forall m gradient layout device dataType shape generatorDevice.
  (SGetGeneratorDevice generatorDevice, MonadThrow m) =>
  Tensor gradient layout device dataType shape ->
  Generator generatorDevice ->
  (ForeignPtr ATen.Tensor -> ForeignPtr ATen.Generator -> IO (ForeignPtr ATen.Tensor)) ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sForwardWithGenerator (UnsafeTensor input) (UnsafeGenerator tvar) rawForwardFn =
  unsafeThrowableIO $ do
    genPtr <- getGenPtr tvar
    output <- rawForwardFn input genPtr
    nextTVar <- newTVarIO (Right genPtr)
    pure (UnsafeTensor output, UnsafeGenerator nextTVar)