{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.Typed.NN.Dropout where

import GHC.Generics
import System.IO.Unsafe
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Specification for a 'Dropout' layer.
--
-- Holds only the dropout probability. The 'Randomizable' instance in this
-- module copies it into a 'Dropout' unchanged (no validation is performed).
data DropoutSpec where
  DropoutSpec ::
    {dropoutProbSpec :: Double} ->
    DropoutSpec
  deriving (Show, Eq)

-- | A dropout layer.
--
-- The only field is the probability passed to
-- 'Torch.Typed.Functional.dropout' by 'dropoutForward'.
-- 'Parameterized' is derived through @DeriveAnyClass@, which relies on the
-- 'Generic' instance derived alongside it. Because the field is a plain
-- 'Double', the layer presumably exposes no trainable parameters
-- (NOTE(review): confirm against "Torch.Typed.Parameter").
data Dropout where
  Dropout ::
    {dropoutProb :: Double} ->
    Dropout
  deriving (Show, Generic, Parameterized)

-- | Apply dropout using the layer's configured probability.
--
-- The 'Bool' is forwarded unchanged as the training flag of
-- 'Torch.Typed.Functional.dropout', together with 'dropoutProb'.
-- The result lives in 'IO' because that is the result type of 'dropout'
-- (presumably because training-mode dropout samples a random mask —
-- TODO confirm). The output has the same device, dtype, and shape as the
-- input, as the type enforces.
dropoutForward ::
  forall shape dtype device.
  -- | layer holding the dropout probability
  Dropout ->
  -- | training flag passed through to 'dropout'
  Bool ->
  -- | input tensor
  Tensor device dtype shape ->
  IO (Tensor device dtype shape)
dropoutForward layer train = dropout (dropoutProb layer) train

-- | Dropout as a forward-pass layer. Input and output tensors share device,
-- dtype, and shape.
instance HasForward Dropout (Tensor device dtype shape) (Tensor device dtype shape) where
  -- Pure forward pass: runs 'dropoutForward' with the training flag 'False'.
  -- 'unsafePerformIO' is only acceptable here if that call is deterministic
  -- and free of observable side effects when training is off
  -- (NOTE(review): relies on 'Torch.Typed.Functional.dropout' not consuming
  -- randomness in eval mode — confirm).
  -- The parameter is named 'layer' so it does not shadow the imported
  -- 'dropout' function.
  forward layer input = unsafePerformIO $ dropoutForward layer False input

  -- Stochastic forward pass: runs 'dropoutForward' with the training flag
  -- 'True' and keeps the result in 'IO'.
  forwardStoch layer = dropoutForward layer True

-- | Build a 'Dropout' layer from its spec by copying the probability across.
--
-- No randomness and no range validation are involved. The value is wrapped in
-- 'IO' only to fit the 'Randomizable' interface.
instance Randomizable DropoutSpec Dropout where
  sample (DropoutSpec prob) = pure (Dropout prob)