module Torch.Distributions.Distribution
  ( Scale,
    Distribution (..),
    stddev,
    perplexity,
    logitsToProbs,
    clampProbs,
    probsToLogits,
    extendedShape,
  )
where

import Torch.Distributions.Constraints
import qualified Torch.Functional as F
import qualified Torch.Tensor as D
import Torch.TensorFactories (ones, onesLike)

-- | Tags which representation a parameter tensor is given in: probabilities
-- or logits. NOTE(review): presumably consumed by distribution constructors
-- elsewhere; see 'logitsToProbs' / 'probsToLogits' for the conversions.
data Scale
  = Probs
  | Logits

-- | Interface shared by probability distributions. The method set mirrors
-- PyTorch's @torch.distributions.Distribution@ (NOTE(review): inferred from
-- the PyTorch references elsewhere in this module — confirm).
class Distribution a where
  -- | Batch dimensions of the distribution. 'extendedShape' places them
  -- after the sample dimensions and before the event dimensions.
  batchShape :: a -> [Int]

  -- | Event dimensions of the distribution; these are the trailing
  -- dimensions of the shape computed by 'extendedShape'.
  eventShape :: a -> [Int]

  -- | Returns a reshaped copy of the distribution. NOTE(review): the list is
  -- presumably the new batch shape, as in PyTorch's @expand@ — confirm.
  expand :: a -> [Int] -> a

  -- | Constraint describing the set of values the distribution is defined on.
  support :: a -> Constraint

  -- | Mean of the distribution.
  mean :: a -> D.Tensor

  -- | Variance of the distribution ('stddev' is its elementwise square root).
  variance :: a -> D.Tensor

  -- | Draws a sample. NOTE(review): the list is presumably the sample shape,
  -- making the result shape @'extendedShape' d sampleShape@ — confirm.
  sample :: a -> [Int] -> IO D.Tensor

  -- | Presumably the log probability density/mass evaluated at the given
  -- value — confirm against the instances.
  logProb :: a -> D.Tensor -> D.Tensor

  -- | Entropy of the distribution ('perplexity' is its elementwise exponential).
  entropy :: a -> D.Tensor

  -- | Tensor enumerating the values of the support. The flag corresponds to
  -- PyTorch's @expand@ argument, which defaults to @True@ there
  -- (NOTE(review): confirm the semantics against the instances).
  enumerateSupport :: a -> Bool -> D.Tensor

-- | Standard deviation of the distribution: the elementwise square root of
-- its 'variance'.
stddev :: (Distribution a) => a -> D.Tensor
stddev = F.sqrt . variance

-- | Perplexity of the distribution: the elementwise exponential of its
-- 'entropy'.
--
-- NOTE(review): an earlier annotation suggested a Float tensor of shape
-- @batchShape@ — confirm against the instances.
perplexity :: (Distribution a) => a -> D.Tensor
perplexity = F.exp . entropy

-- | Converts a tensor of logits into probabilities.
--
-- When the flag is 'True' (the binary case), each value denotes log odds
-- and is mapped through a sigmoid. When it is 'False' (the multi-dimensional
-- case, which is the @is_binary=False@ default of the PyTorch original), the
-- values along the last dimension denote the (possibly unnormalized) log
-- probabilities of the events and are normalized with a softmax over that
-- dimension.
logitsToProbs :: Bool -> D.Tensor -> D.Tensor
logitsToProbs True = F.sigmoid
logitsToProbs False = F.softmax (F.Dim (-1))

-- | Clamps probabilities into the closed interval @[eps, 1 - eps]@ with
-- @eps = 1e-6@, so that logarithms taken afterwards (see 'probsToLogits')
-- stay finite.
--
-- NOTE(review): PyTorch uses the machine epsilon of the tensor's dtype
-- (@torch.finfo(probs.dtype).eps@); this port uses a fixed value instead.
clampProbs :: D.Tensor -> D.Tensor
clampProbs = F.clamp eps (1 - eps)
  where
    eps :: Float
    eps = 1.0e-6

-- | Converts a tensor of probabilities into logits.
--
-- In the binary case (flag 'True'), each value is the probability of
-- occurrence of the event indexed by @1@, and the result is the log odds
-- @log p - log (1 - p)@. In the multi-dimensional case (flag 'False', the
-- @is_binary=False@ default of the PyTorch original), the values along the
-- last dimension are the probabilities of each event, and the result is their
-- elementwise natural logarithm.
--
-- Natural logarithms are required so that 'logitsToProbs' inverts this
-- function up to clamping, using sigmoid or softmax (the latter for normalized
-- probabilities).
--
-- Inputs are first passed through 'clampProbs'.
probsToLogits :: Bool -> D.Tensor -> D.Tensor
probsToLogits isBinary probs
  | isBinary =
      F.log psClamped `F.sub` F.log1p (F.mulScalar (-1.0 :: Float) psClamped)
  | otherwise = F.log psClamped
  where
    -- Keep probabilities away from 0 and 1 so neither log term is infinite.
    psClamped :: D.Tensor
    psClamped = clampProbs probs

-- | Returns the shape of a sample drawn from the distribution for a given
-- @sampleShape@. The result is the sample dimensions, followed by the
-- distribution's 'batchShape', followed by its 'eventShape'.
--
-- The batch and event shapes of a distribution instance are fixed when it is
-- constructed. No upcasting is performed: when all three shapes are empty,
-- the result is the empty (scalar) shape, not @[1]@.
--
-- The @sampleShape@ argument is the shape of the sample to be drawn. It is
-- the analogue of PyTorch's @torch.Size@ argument.
extendedShape :: (Distribution a) => a -> [Int] -> [Int]
extendedShape d sampleShape = sampleShape <> batchShape d <> eventShape d