{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

-- | Bernoulli distribution over untyped tensors, parameterised by
-- probabilities and the matching logits.
module Torch.Distributions.Bernoulli
  ( Bernoulli (..),
    fromProbs,
    fromLogits,
  )
where

import qualified Torch.DType as D
import qualified Torch.Distributions.Constraints as Constraints
import Torch.Distributions.Distribution
import qualified Torch.Functional as F
import qualified Torch.Functional.Internal as I
import Torch.Scalar
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.TensorOptions
import Torch.Typed.Functional (reductionVal)

-- | A Bernoulli distribution carrying both parameterisations.
--
-- Prefer the smart constructors 'fromProbs' and 'fromLogits', which derive
-- one field from the other; building the record by hand makes it possible
-- for the two fields to disagree.
data Bernoulli = Bernoulli
  { -- | Success probabilities; read by 'mean', 'variance', 'sample' and
    -- 'entropy'.
    probs :: D.Tensor,
    -- | Logits; read by 'logProb' and 'entropy'.
    logits :: D.Tensor
  }
  deriving (Show)

instance Distribution Bernoulli where
  -- NOTE(review): the batch shape is reported as empty regardless of the
  -- shape of 'probs'; 'sample' and 'enumerateSupport' both depend on it.
  -- Confirm this is intended for non-scalar parameters.
  batchShape _d = []

  eventShape _d = []

  -- Expands 'probs' to the requested shape and re-derives the logits from
  -- the expanded probabilities via 'fromProbs'.
  expand d = fromProbs . F.expand (probs d) False

  support _d = Constraints.boolean

  mean = probs

  -- p * (1 - p)
  variance d = p `F.mul` (D.onesLike p `F.sub` p)
    where
      p = probs d

  -- Expands 'probs' to the extended sample shape, then draws one Bernoulli
  -- sample per element.
  sample d = D.bernoulliIO' . F.expand (probs d) False . extendedShape d

  -- Negated binary cross-entropy between the logits and the observed value.
  logProb d value = F.mulScalar (-1 :: Int) (bce' (logits d) value)

  -- Binary cross-entropy of the distribution against its own probabilities.
  entropy d = bce' (logits d) $ probs d

  -- The two support points 0.0 and 1.0, reshaped to @[-1, 1, ..., 1]@ with
  -- one trailing unit dimension per batch dimension, and optionally expanded
  -- over the batch shape.
  enumerateSupport d doExpand
    | doExpand = F.expand values False ([-1] <> batchShape d)
    | otherwise = values
    where
      values =
        D.reshape ([-1] <> replicate (length $ batchShape d) 1) $
          D.asTensor [0.0, 1.0 :: Float]

-- | Element-wise (unreduced, 'F.ReduceNone') binary cross-entropy with
-- logits.
--
-- The third argument passed to the underlying call is a tensor of ones
-- shaped like the logits, and the fourth is a float tensor of ones whose
-- length is the size of the logits' last dimension (presumably the weight
-- and positive-class weight of the ATen function -- confirm against
-- "Torch.Functional.Internal").
bce' ::
  -- | logits
  D.Tensor ->
  -- | target values
  D.Tensor ->
  D.Tensor
bce' ls target =
  I.binary_cross_entropy_with_logits
    ls
    target
    (D.onesLike ls)
    (D.ones [D.size (-1) ls] D.float_opts)
    $ reductionVal @(F.ReduceNone)

-- | Build a 'Bernoulli' from success probabilities, deriving the logits
-- with 'probsToLogits'.
--
-- NOTE(review): the flag passed to 'probsToLogits' is @False@. If that flag
-- selects the binary (log-odds) conversion, as @is_binary@ does in PyTorch's
-- @probs_to_logits@, it should probably be @True@ here -- confirm against
-- "Torch.Distributions.Distribution".
fromProbs :: D.Tensor -> Bernoulli
fromProbs ps = Bernoulli ps $ probsToLogits False ps

-- | Build a 'Bernoulli' from logits (log-odds).
--
-- The probabilities are recovered with the logistic sigmoid, the inverse of
-- the log-odds map, so that 'probs' holds genuine probabilities. Running the
-- logits through a probability-to-logit conversion instead would store
-- non-probabilities in 'probs'.
fromLogits :: D.Tensor -> Bernoulli
fromLogits ls = Bernoulli (F.sigmoid ls) ls