{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Torch.GraduallyTyped.NN.Functional.Dropout where
import Control.Monad.Catch (MonadThrow, throwM)
import Foreign.ForeignPtr (ForeignPtr)
import Torch.GraduallyTyped.Device (DeviceType (..))
import Torch.GraduallyTyped.Random (Generator, SGetGeneratorDevice, sForwardWithGenerator)
import Torch.GraduallyTyped.Tensor.Type (SGetDevice (..), Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.Internal.Cast (cast3)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen (_fused_dropout_tdG)
import qualified Torch.Internal.Managed.Type.Tuple as ATen ()
import qualified Torch.Internal.Type as ATen (Tensor)
import Unsafe.Coerce (unsafeCoerce)
-- | Randomly zero elements of a tensor with probability @p@ (dropout).
--
-- On a CUDA tensor this calls ATen's @_fused_dropout@ through
-- 'sForwardWithGenerator', which threads the random 'Generator' so the
-- updated generator is returned alongside the result.
--
-- On a CPU tensor the input tensor and generator are returned unchanged.
-- No elements are dropped, and @p@ is only range-checked.
-- NOTE(review): CPU dropout is therefore a no-op; confirm this is intended.
--
-- Throws (via 'MonadThrow') when @p@ is not in the closed interval [0, 1].
-- NaN is rejected as well.
dropout ::
  forall gradient layout device dataType shape generatorDevice m.
  (SGetDevice device, SGetGeneratorDevice generatorDevice, MonadThrow m) =>
  -- | probability of an element being zeroed, expected in [0, 1]
  Double ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | random number generator
  Generator generatorDevice ->
  -- | output tensor, paired with the generator to use for subsequent draws
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
dropout p tensor g
  -- Written as a negated conjunction so that NaN is rejected too.
  | not (p >= 0 && p <= 1) =
      throwM . userError $
        "dropout: probability has to be between 0 and 1, but got " ++ show p
  | otherwise =
      case getDeviceType tensor of
        CPU ->
          -- Nothing is computed on CPU. Only the device index of the tensor
          -- and generator is retagged to @device <+> generatorDevice@.
          -- NOTE(review): this relies on the type indices of 'Tensor' and
          -- 'Generator' being phantom, so that the coercion preserves the
          -- runtime representation. Confirm against their definitions.
          pure
            ( unsafeCoerce
                @(Tensor gradient layout device dataType shape)
                @(Tensor gradient layout (device <+> generatorDevice) dataType shape)
                tensor,
              unsafeCoerce
                @(Generator generatorDevice)
                @(Generator (device <+> generatorDevice))
                g
            )
        CUDA _ ->
          sForwardWithGenerator tensor g $ \tPtr genPtr -> do
            -- _fused_dropout expects the probability of *keeping* an element,
            -- hence 1 - p. It returns (output, mask); the mask is discarded.
            (t :: ForeignPtr ATen.Tensor, _ :: ForeignPtr ATen.Tensor) <-
              cast3 ATen._fused_dropout_tdG tPtr (1 - p) genPtr
            pure t