{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

module Torch.Distributions.Bernoulli
  ( Bernoulli (..),
    fromProbs,
    fromLogits,
  )
where

import qualified Torch.DType as D
import qualified Torch.Distributions.Constraints as Constraints
import Torch.Distributions.Distribution
import qualified Torch.Functional as F
import qualified Torch.Functional.Internal as I
import Torch.Scalar
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.TensorOptions
import Torch.Typed.Functional (reductionVal)

-- | A Bernoulli distribution, stored in both of its parameterisations.
--
-- Prefer building values with 'fromProbs' or 'fromLogits', which derive one
-- field from the other, over applying the constructor to two independent
-- tensors.
data Bernoulli = Bernoulli
  { -- | Probabilities; returned by 'mean' and expanded before sampling.
    probs :: D.Tensor,
    -- | Logits; passed to the binary cross-entropy used by 'logProb' and
    -- 'entropy'.
    logits :: D.Tensor
  }
  deriving (Show)

-- | 'Distribution' instance for 'Bernoulli'.
instance Distribution Bernoulli where
  -- Both shapes are reported as empty regardless of the parameter tensors.
  batchShape _ = []
  eventShape _ = []

  -- Expand the probabilities to the requested shape and rebuild the
  -- distribution (and hence its logits) from them via 'fromProbs'.
  expand d = fromProbs . F.expand (probs d) False

  support _ = Constraints.boolean

  mean = probs

  -- p * (1 - p)
  variance d = p `F.mul` (D.onesLike p `F.sub` p)
    where
      p = probs d

  -- Expand the probabilities to the extended sample shape, then hand them to
  -- 'D.bernoulliIO'' to draw the sample.
  sample d = D.bernoulliIO' . F.expand (probs d) False . extendedShape d

  -- Negated element-wise binary cross-entropy between the logits and `value`.
  logProb d value = F.mulScalar (-1 :: Int) (bce' (logits d) value)

  -- Binary cross-entropy of the logits against the distribution's own
  -- probabilities.
  entropy d = bce' (logits d) $ probs d

  -- The two support values 0 and 1 laid out along a leading axis, followed by
  -- one singleton axis per batch dimension; optionally expanded over the
  -- batch shape.
  enumerateSupport d doExpand
    | doExpand = F.expand values False ([-1] <> batchShape d)
    | otherwise = values
    where
      values =
        D.reshape ([-1] <> replicate (length $ batchShape d) 1) $
          D.asTensor [0.0, 1.0 :: Float]

-- | Unreduced binary cross-entropy with logits.
--
-- The first argument is the logits tensor and the second is the target it is
-- scored against ('logProb' passes the observed value, 'entropy' passes the
-- distribution's own probabilities). The reduction is fixed to
-- 'F.ReduceNone'.
bce' :: D.Tensor -> D.Tensor -> D.Tensor
bce' logitsT target =
  I.binary_cross_entropy_with_logits
    logitsT
    target
    unitWeight
    unitPosWeight
    (reductionVal @F.ReduceNone)
  where
    -- Ones with the same shape as the logits.
    -- NOTE(review): presumably the op's per-element weight argument, left
    -- neutral -- confirm against the ATen signature.
    unitWeight = D.onesLike logitsT
    -- Ones sized by the last dimension of the logits.
    -- NOTE(review): presumably the op's positive-class weight argument, left
    -- neutral -- confirm against the ATen signature.
    unitPosWeight = D.ones [D.size (-1) logitsT] D.float_opts

-- | Build a 'Bernoulli' from a tensor of probabilities.
--
-- The @logits@ field is derived from the probabilities with 'probsToLogits'.
-- Because that field is later fed to a binary cross-entropy *with logits*
-- (which applies a sigmoid), it has to hold log-odds, so the binary form of
-- the conversion is requested.
--
-- NOTE(review): assumes the 'Bool' argument of 'probsToLogits' is an
-- @isBinary@ flag (True = log p - log (1 - p)), mirroring PyTorch's
-- @probs_to_logits(probs, is_binary=True)@ -- confirm against
-- "Torch.Distributions.Distribution".
fromProbs :: D.Tensor -> Bernoulli
fromProbs ps = Bernoulli ps (probsToLogits True ps)

-- | Build a 'Bernoulli' from a tensor of logits.
--
-- The @probs@ field is derived from the logits with 'logitsToProbs'; the
-- logits themselves are stored unchanged. (Applying 'probsToLogits' to the
-- logits, as was done previously, converts in the wrong direction and does
-- not yield probabilities.)
--
-- NOTE(review): assumes 'logitsToProbs' is exported by
-- "Torch.Distributions.Distribution" and that its 'Bool' argument is an
-- @isBinary@ flag (True = sigmoid), mirroring PyTorch's
-- @logits_to_probs(logits, is_binary=True)@ -- confirm.
fromLogits :: D.Tensor -> Bernoulli
fromLogits ls = Bernoulli (logitsToProbs True ls) ls