module Torch.Distributions.Distribution
( Scale,
Distribution (..),
stddev,
perplexity,
logitsToProbs,
clampProbs,
probsToLogits,
extendedShape,
)
where
import Torch.Distributions.Constraints
import qualified Torch.Functional as F
import qualified Torch.Tensor as D
import Torch.TensorFactories (ones, onesLike)
-- | Tag with two alternatives, 'Probs' and 'Logits'.
--
-- NOTE(review): presumably selects whether a parameter tensor is given as
-- probabilities or as logits; nothing in this module consumes it, so confirm
-- against the distribution instances that do.
data Scale
  = Probs
  | Logits
-- | Common interface implemented by every distribution type.
--
-- The method implementations live in the instances, which are not part of
-- this module; the per-method notes below are derived from the method names
-- and types only and are hedged where the meaning is not visible here.
class Distribution a where
  -- | Batch dimensions of the distribution; concatenated between the sample
  -- shape and the event shape by 'extendedShape'.
  batchShape :: a -> [Int]

  -- | Event dimensions of the distribution; appended last by
  -- 'extendedShape'.
  eventShape :: a -> [Int]

  -- | Returns a distribution of the same type built from the given shape.
  -- NOTE(review): presumably the new batch shape -- confirm in instances.
  expand :: a -> [Int] -> a

  -- | The 'Constraint' describing the distribution's support.
  support :: a -> Constraint

  -- | Mean of the distribution, as a tensor.
  mean :: a -> D.Tensor

  -- | Variance of the distribution, as a tensor; 'stddev' is its square
  -- root.
  variance :: a -> D.Tensor

  -- | Produces a tensor in 'IO'.
  -- NOTE(review): the @[Int]@ argument is presumably the sample shape --
  -- confirm in instances.
  sample :: a -> [Int] -> IO D.Tensor

  -- | Maps a tensor to a tensor.
  -- NOTE(review): presumably the log-density / log-mass evaluated at the
  -- given values -- confirm in instances.
  logProb :: a -> D.Tensor -> D.Tensor

  -- | Entropy of the distribution, as a tensor; 'perplexity' is its
  -- exponential.
  entropy :: a -> D.Tensor

  -- | Tensor enumerating the support.
  -- NOTE(review): the 'Bool' is presumably an "expand over the batch shape"
  -- flag -- TODO confirm against the instances.
  enumerateSupport :: a -> Bool -> D.Tensor
-- | Standard deviation of a distribution: the square root ('F.sqrt') of its
-- 'variance'.
stddev :: (Distribution a) => a -> D.Tensor
stddev dist = F.sqrt (variance dist)
-- | Perplexity of a distribution: the exponential ('F.exp') of its
-- 'entropy'.
perplexity :: (Distribution a) => a -> D.Tensor
perplexity dist = F.exp (entropy dist)
-- | Converts a tensor of logits into probabilities.
--
-- The 'Bool' selects the binary case:
--
-- * 'True'  -- applies 'F.sigmoid' to the logits.
-- * 'False' -- applies 'F.softmax' along dimension @-1@ (the last one).
logitsToProbs :: Bool -> D.Tensor -> D.Tensor
logitsToProbs isBinary
  | isBinary = F.sigmoid
  | otherwise = F.softmax (F.Dim (-1))
-- | Clamps a tensor of probabilities into the closed interval
-- @[eps, 1 - eps]@ with @eps = 1e-6@, using 'F.clamp'.
--
-- 'probsToLogits' applies this before taking logarithms, so exact zeros and
-- ones never reach the log.
clampProbs :: D.Tensor -> D.Tensor
clampProbs = F.clamp eps (1.0 - eps)
  where
    -- Fixed tolerance; it is not derived from the tensor's dtype.
    eps :: Float
    eps = 1.0e-6
-- | Converts a tensor of probabilities into logits.
--
-- The probabilities are first passed through 'clampProbs'. The 'Bool'
-- selects the binary case:
--
-- * 'True'  -- returns @log p - log (1 - p)@, with the second term computed
--   as @log1p (-p)@.
-- * 'False' -- returns @log p@.
--
-- The natural logarithm ('F.log') is used throughout. A base-10 logarithm
-- would be inconsistent with the natural-log 'F.log1p' term in the binary
-- branch, and neither 'F.sigmoid' nor 'F.softmax' (see 'logitsToProbs')
-- would map the result back to the original probabilities.
probsToLogits :: Bool -> D.Tensor -> D.Tensor
probsToLogits isBinary probs
  | isBinary = logP `F.sub` F.log1p (F.mulScalar (-1.0 :: Float) clamped)
  | otherwise = logP
  where
    -- Probabilities kept strictly inside (0, 1) so both logs stay finite.
    clamped :: D.Tensor
    clamped = clampProbs probs

    -- Natural log of the clamped probabilities, shared by both branches.
    logP :: D.Tensor
    logP = F.log clamped
-- | Shape obtained by concatenating, in this order, the requested sample
-- shape, the distribution's 'batchShape' and its 'eventShape'.
extendedShape :: (Distribution a) => a -> [Int] -> [Int]
extendedShape d sampleShape =
  concat [sampleShape, batchShape d, eventShape d]