module Torch.Distributions.Categorical
  ( Categorical (..),
    fromProbs,
    fromLogits,
  )
where

import qualified Torch.DType as D
import qualified Torch.Distributions.Constraints as Constraints
import Torch.Distributions.Distribution
import qualified Torch.Functional as F
import qualified Torch.Functional.Internal as I
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D

-- | A categorical distribution over the class indices @0 .. K-1@, where @K@
-- is the size of the last dimension of 'probs'.
--
-- Build values with 'fromProbs' or 'fromLogits'. Each of those derives one
-- field from the other, so the two stay consistent.
--
-- * If 'probs' is 1-D with length @K@, each element is the relative
--   probability of sampling the class at that index.
-- * If 'probs' is 2-D, it is treated as a batch of relative probability
--   vectors, one per row. More generally, every dimension except the last
--   is a batch dimension.
--
-- Probabilities must be non-negative and finite, and must have a non-zero
-- sum along the last dimension. Sampling is implemented with
-- 'D.multinomialIO', so it draws from the same distribution as
-- @torch.multinomial@.
--
-- ==== __Example__
--
-- @
-- let m = fromProbs (D.asTensor [0.25, 0.25, 0.25, 0.25 :: Float])
-- s <- sample m []  -- 0, 1, 2 or 3, each with probability 1/4
-- @
data Categorical = Categorical
  { -- | Class probabilities; the last dimension indexes the categories.
    probs :: D.Tensor,
    -- | Log-space parameters (logits) corresponding to 'probs'.
    logits :: D.Tensor
  }
  deriving (Show)

instance Distribution Categorical where
  -- Every dimension of the parameter except the last (the class axis) is a
  -- batch dimension. A 0-d or 1-D parameter has no batch dimensions.
  -- The test is on the rank rather than on the element count, so a [1, 1]
  -- parameter correctly reports batch shape [1]. 'init' is safe here because
  -- the list has at least two elements.
  batchShape d =
    let paramShape = D.shape (probs d)
     in if length paramShape > 1 then init paramShape else []

  -- Each draw is a single class index, so events are scalars.
  eventShape _d = []

  -- Broadcast the probabilities to @batchShape' <> [K]@ and rebuild the
  -- distribution from them.
  expand d batchShape' =
    fromProbs $ F.expand (probs d) False (batchShape' <> [numEvents d])

  -- Samples are the integers 0 .. K-1.
  support d = Constraints.integerInterval 0 (numEvents d - 1)

  -- The mean of a categorical variable is undefined, so every entry is NaN.
  -- Dividing ones by zero would give +Inf rather than NaN, hence the
  -- explicit NaN scalar.
  mean d = F.mulScalar nan (D.ones (extendedShape d []) D.float_opts)
    where
      nan = 0 / 0 :: Float

  -- Undefined for the same reason as the mean: all NaN.
  variance d = mean d

  sample d sampleShape = do
    -- 'D.multinomialIO' samples each row of a 2-D [rows, K] matrix, so all
    -- batch dimensions are collapsed into the rows.
    let probs2d = D.reshape [-1, numEvents d] (probs d)
        numSamples = product sampleShape
    -- Transpose [rows, numSamples] to [numSamples, rows] so the sample axis
    -- comes first before the final reshape.
    samples2d <- F.transpose2D <$> D.multinomialIO probs2d numSamples True
    pure $ D.reshape (extendedShape d sampleShape) samples2d

  -- Log-probability of each class index in @value@, read from 'logits'.
  -- 'I.gather' requires the index to have the same rank as its input. The
  -- index (with a trailing singleton axis) is therefore broadcast against
  -- the batch dimensions of 'logits', like PyTorch's @broadcast_tensors@
  -- followed by @value[..., :1]@. This makes a scalar index work for an
  -- unbatched distribution, and one index per batch element work for a
  -- batched one.
  logProb d value =
    F.squeezeDim (-1) $ I.gather lgts' (-1 :: Int) index' False
    where
      -- Class indices as Int64 with a trailing singleton axis: [..., 1].
      index = I.unsqueeze (F.toDType D.Int64 value) (-1 :: Int)
      -- Common shape of the index and the logits. It is never empty, because
      -- 'index' has at least one dimension, so 'init' below is safe.
      fullShape = broadcastShape (D.shape index) (D.shape (logits d))
      index' = F.expand index False (init fullShape <> [1])
      lgts' = F.expand (logits d) False fullShape
      -- Right-aligned broadcasting of two shapes: a size-1 axis stretches to
      -- the other operand's size. Incompatible sizes are left for 'F.expand'
      -- to reject.
      broadcastShape :: [Int] -> [Int] -> [Int]
      broadcastShape xs ys = reverse (go (reverse xs) (reverse ys))
        where
          go (a : restA) (b : restB) = (if a == 1 then b else a) : go restA restB
          go restA [] = restA
          go [] restB = restB

  -- Entropy @-sum_k p_k * l_k@ over the class axis, treating @l@ = 'logits'
  -- as log-probabilities.
  -- NOTE(review): PyTorch first clamps logits to the dtype's finite minimum,
  -- so that a zero probability paired with a -Inf logit contributes 0 rather
  -- than NaN. That clamp is not applied here.
  entropy d =
    F.mulScalar (-1.0 :: Float) $
      F.sumDim (F.Dim (-1)) F.RemoveDim (D.dtype pLogP) pLogP
    where
      pLogP = logits d `F.mul` probs d

  -- All class indices @0 .. K-1@, as an integer tensor built from @[Int]@,
  -- laid out along a new leading axis with one singleton axis per batch
  -- dimension. When @doExpand@ is set, they are expanded across the batch
  -- shape.
  enumerateSupport d doExpand
    | doExpand = F.expand values False ([-1] <> batchShape d)
    | otherwise = values
    where
      values =
        D.reshape ([-1] <> replicate (length (batchShape d)) 1) $
          D.asTensor ([0 .. numEvents d - 1] :: [Int])

-- | Number of categories @K@, i.e. the size of the last dimension of 'probs'.
numEvents :: Categorical -> Int
numEvents = D.size (-1) . probs

-- | Build a 'Categorical' from class probabilities whose last dimension
-- indexes the categories.
--
-- The input must be non-negative and finite, with a non-zero sum along the
-- last dimension. It is normalised to sum to 1 along that dimension before
-- being stored. 'logits' is then derived from the normalised probabilities
-- with 'probsToLogits', so 'logProb' and 'entropy' see true
-- log-probabilities.
fromProbs :: D.Tensor -> Categorical
fromProbs ps = Categorical normalised (probsToLogits False normalised)
  where
    -- Keep the summed axis (size 1) so the division broadcasts over the
    -- class axis.
    normalised = ps / F.sumDim (F.Dim (-1)) F.KeepDim (D.dtype ps) ps

-- | Build a 'Categorical' from unnormalised log-probabilities (logits) whose
-- last dimension indexes the categories.
--
-- The stored 'logits' are normalised with a log-softmax over the last
-- dimension, so they are true log-probabilities. That is what 'logProb' and
-- 'entropy' read. 'probs' is computed from the raw input by 'logitsToProbs'.
fromLogits :: D.Tensor -> Categorical
fromLogits logits' =
  Categorical (logitsToProbs False logits') (F.logSoftmax (F.Dim (-1)) logits')