{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

module Torch.GraduallyTyped.NN.Functional.Dropout where

import Control.Monad.Catch (MonadThrow)
import Foreign.ForeignPtr (ForeignPtr)
import Torch.GraduallyTyped.Device (DeviceType (..))
import Torch.GraduallyTyped.Random (Generator, SGetGeneratorDevice, sForwardWithGenerator)
import Torch.GraduallyTyped.Tensor.Type (SGetDevice (..), Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.Internal.Cast (cast3)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen (_fused_dropout_tdG)
import qualified Torch.Internal.Managed.Type.Tuple as ATen ()
import qualified Torch.Internal.Type as ATen (Tensor)
import Unsafe.Coerce (unsafeCoerce)

-- $setup
-- >>> import Torch.GraduallyTyped.Prelude.List (SList (..))
-- >>> import Torch.GraduallyTyped

-- | Dropout randomly zeroes some of the elements of
-- the input tensor with probability @p@ using samples from a Bernoulli distribution.
--
-- Dispatch is on the runtime device of the input tensor:
--
-- * on a CUDA device, the work is delegated to ATen's @_fused_dropout@
--   through 'sForwardWithGenerator', which returns the resulting tensor
--   paired with a generator;
--
-- * on the CPU, the input tensor and the generator are returned unchanged
--   (only their type-level device index is adjusted), i.e. no element is
--   zeroed. NOTE(review): this is a pass-through, presumably because the
--   fused kernel is CUDA-only. Confirm this is intended before relying on
--   dropout for regularisation on CPU.
dropout ::
  forall gradient layout device dataType shape generatorDevice m.
  (SGetDevice device, SGetGeneratorDevice generatorDevice, MonadThrow m) =>
  -- | probability of an element to be zeroed
  Double ->
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | generator
  Generator generatorDevice ->
  -- | output
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
dropout p tensor g =
  case getDeviceType tensor of
    CPU ->
      -- Only the type-level device index changes here. This relies on that
      -- index being phantom in 'Tensor' and 'Generator'
      -- (NOTE(review): confirm against their definitions).
      pure
        ( unsafeCoerce
            @(Tensor gradient layout device dataType shape)
            @(Tensor gradient layout (device <+> generatorDevice) dataType shape)
            tensor,
          unsafeCoerce
            @(Generator generatorDevice)
            @(Generator (device <+> generatorDevice))
            g
        )
    CUDA _ ->
      sForwardWithGenerator tensor g $
        \tPtr genPtr -> do
          -- @p@ is the probability of zeroing an element, so @1 - p@ (the
          -- probability of keeping one) is what is handed to the kernel.
          -- The kernel yields a pair of tensors; only the first is kept and
          -- the second (presumably the mask) is discarded.
          (t :: ForeignPtr ATen.Tensor, _ :: ForeignPtr ATen.Tensor) <-
            cast3 ATen._fused_dropout_tdG tPtr (1 - p) genPtr
          pure t