{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Typed dropout layer: a hyper-parameter spec ('DropoutSpec'), the layer
-- itself ('Dropout'), a forward helper ('dropoutForward'), and the
-- 'HasForward' / 'Randomizable' instances that plug it into "Torch.NN".
module Torch.Typed.NN.Dropout where

import GHC.Generics
import System.IO.Unsafe
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Specification from which a 'Dropout' layer is built via 'sample'.
--
-- 'dropoutProbSpec' is copied verbatim into the layer's 'dropoutProb'.
data DropoutSpec where
  DropoutSpec ::
    {dropoutProbSpec :: Double} ->
    DropoutSpec
  deriving (Show, Eq)

-- | A dropout layer.
--
-- Its only state is 'dropoutProb', the probability forwarded to
-- 'Torch.Typed.Functional.dropout'. 'Parameterized' is derived through
-- 'Generic'; since the layer holds just a plain 'Double', it presumably
-- exposes no trainable tensors (NOTE(review): confirm against the
-- 'Parameterized' generic defaults).
data Dropout where
  Dropout ::
    {dropoutProb :: Double} ->
    Dropout
  deriving (Show, Generic, Parameterized)

-- | Apply dropout using this layer's 'dropoutProb'.
--
-- The 'Bool' flag is forwarded unchanged to 'dropout' and selects training
-- mode. Device, dtype and shape are preserved by the result type.
dropoutForward ::
  forall shape dtype device.
  Dropout ->
  -- | training mode flag, passed straight to 'dropout'
  Bool ->
  -- | input tensor
  Tensor device dtype shape ->
  IO (Tensor device dtype shape)
dropoutForward Dropout {..} dropoutTrain =
  dropout dropoutProb dropoutTrain

-- | 'forward' runs dropout with training disabled; 'forwardStoch' runs it
-- with training enabled.
instance HasForward Dropout (Tensor device dtype shape) (Tensor device dtype shape) where
  -- 'unsafePerformIO' lets the deterministic 'forward' stay pure. This
  -- relies on dropout with training disabled doing no random sampling
  -- (libtorch returns the input unchanged in that mode). Use 'forwardStoch'
  -- when random masking is wanted.
  --
  -- The binder is named 'layer' (not 'dropout') so it does not shadow
  -- 'Torch.Typed.Functional.dropout'.
  forward layer input = unsafePerformIO $ dropoutForward layer False input
  forwardStoch layer input = dropoutForward layer True input

-- | Building a 'Dropout' from its spec draws no randomness: the probability
-- is copied as-is.
instance Randomizable DropoutSpec Dropout where
  sample DropoutSpec {..} = pure $ Dropout dropoutProbSpec