{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Dropout where
import GHC.Generics (Generic)
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.NN.Functional.Dropout (dropout)
import Torch.GraduallyTyped.Random (SGetGeneratorDevice)
import Torch.GraduallyTyped.Tensor.Type (SGetDevice, Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))
-- | A dropout layer.
--
-- The layer holds a single 'Double', which 'forward' passes as the first
-- argument to 'dropout'. It has no tensors of its own, so it carries no
-- trainable state.
newtype Dropout where
  Dropout ::
    -- | Rate handed to 'dropout' (presumably the probability that an
    -- element is zeroed; confirm against 'dropout').
    Double ->
    Dropout
  deriving stock (Eq, Ord, Show, Generic)
-- | A 'Dropout' layer is its own model specification: the spec type and
-- the layer type are the same, so initialisation can return the spec as-is.
type instance ModelSpec Dropout = Dropout
-- | Initialising a 'Dropout' layer consumes no randomness.
--
-- The spec is returned as the layer itself (see @ModelSpec Dropout =
-- Dropout@). The generator is handed back untouched, on the same device
-- it arrived on.
instance
  HasInitialize
    Dropout
    generatorDevice
    Dropout
    generatorDevice
  where
  initialize spec generator = pure (spec, generator)
-- | 'Dropout' holds only a 'Double' and no tensors, so it contributes
-- nothing to a state dict.
instance HasStateDict Dropout where
  -- Loading ignores the key and returns the spec unchanged; nothing is read
  -- from the state dict.
  fromStateDict spec _ = pure spec

  -- Saving writes nothing to the state dict.
  toStateDict _ _ = pure ()
-- | Runs 'dropout' on the input tensor using the layer's rate.
--
-- The output tensor keeps the input's gradient, layout, data type and
-- shape. Both the output tensor and the returned generator live on the
-- unified device @device <+> generatorDevice@.
instance
  ( input ~ Tensor gradient layout device dataType shape,
    output ~ Tensor gradient layout (device <+> generatorDevice) dataType shape,
    generatorOutputDevice ~ (device <+> generatorDevice),
    SGetDevice device,
    SGetGeneratorDevice generatorDevice
  ) =>
  HasForward
    Dropout
    input
    generatorDevice
    output
    generatorOutputDevice
  where
  forward (Dropout p) input generator = dropout p input generator