{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Dropout where

import GHC.Generics (Generic)
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.NN.Functional.Dropout (dropout)
import Torch.GraduallyTyped.Random (SGetGeneratorDevice)
import Torch.GraduallyTyped.Tensor.Type (SGetDevice, Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))

-- | Dropout layer: randomly zeroes elements of the input tensor with
-- probability @p@, using samples from a Bernoulli distribution.
--
-- The layer itself holds only the probability; the random generator is
-- supplied (and a new one returned) on every 'forward' call, so sampling is
-- repeated on each call. The computation is delegated to the functional
-- 'dropout'.
--
-- NOTE(review): @p@ is not range-checked here; whether values outside
-- @[0, 1]@ are rejected depends on the functional 'dropout' — confirm.
newtype Dropout where
  Dropout ::
    -- | probability of an element to be zeroed
    Double ->
    Dropout
  deriving stock (Eq, Ord, Show, Generic)

-- | A 'Dropout' is its own model specification: the probability is the
-- only configuration, and it is also the entire model.
type instance ModelSpec Dropout = Dropout

-- | Initializing a 'Dropout' layer consumes no randomness. The
-- specification is returned unchanged as the model, and the generator is
-- passed through untouched, which is why @generatorDevice@ is the same on
-- the input and output side.
instance
  HasInitialize
    Dropout
    generatorDevice
    Dropout
    generatorDevice
  where
  -- Pair the spec (which is already the model) with the incoming generator.
  initialize spec g = pure (spec, g)

-- | 'Dropout' holds no tensors, only its probability, so nothing is read
-- from or written to the state dictionary. Loading ignores the key and
-- rebuilds the layer from its specification alone; saving writes nothing.
instance HasStateDict Dropout where
  fromStateDict spec _key = pure spec
  toStateDict _key _model = pure ()

-- | Applies dropout to a tensor, using the layer's probability and the
-- supplied random generator. The work is delegated to the functional
-- 'dropout'.
--
-- The output tensor and the returned generator both live on the unified
-- device @device <+> generatorDevice@. The gradient, layout, data type and
-- shape of the input tensor are preserved in the output type.
instance
  ( input ~ Tensor gradient layout device dataType shape,
    output ~ Tensor gradient layout (device <+> generatorDevice) dataType shape,
    generatorOutputDevice ~ (device <+> generatorDevice),
    SGetDevice device,
    SGetGeneratorDevice generatorDevice
  ) =>
  HasForward
    Dropout
    input
    generatorDevice
    output
    generatorOutputDevice
  where
  forward (Dropout p) input g = dropout p input g