module Torch.Distributions.Categorical
( Categorical (..),
fromProbs,
fromLogits,
)
where
import qualified Torch.DType as D
import qualified Torch.Distributions.Constraints as Constraints
import Torch.Distributions.Distribution
import qualified Torch.Functional as F
import qualified Torch.Functional.Internal as I
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
-- | A categorical distribution, stored as a pair of tensors.
--
-- Prefer building values with 'fromProbs' or 'fromLogits': each takes one of
-- the two fields and derives the other, so the pair stays consistent.
data Categorical = Categorical
  { -- | Per-category probabilities. The size of the last dimension is used
    -- as the number of categories.
    probs :: D.Tensor,
    -- | Log-space counterpart of 'probs'; this is the tensor that 'logProb'
    -- reads from.
    logits :: D.Tensor
  }
  deriving (Show)
-- | 'Distribution' instance for 'Categorical'.
--
-- The last dimension of 'probs' / 'logits' indexes the categories; every
-- leading dimension is treated as a batch dimension.
instance Distribution Categorical where
  -- All dimensions of 'probs' except the last one. A tensor holding more
  -- than one element has at least one dimension, so 'init' cannot hit an
  -- empty list here.
  batchShape d
    | D.numel ps > 1 = init (D.shape ps)
    | otherwise = []
    where
      ps = probs d

  -- A single draw is a scalar category index, so there is no event shape.
  eventShape _ = []

  -- Expand 'probs' to @batchShape' <> [numEvents d]@ and rebuild the
  -- distribution from the expanded probabilities.
  expand d batchShape' =
    fromProbs $ F.expand (probs d) False (batchShape' <> [numEvents d])

  -- Valid outcomes are the integers @0 .. numEvents d - 1@.
  support d = Constraints.integerInterval 0 (numEvents d - 1)

  -- The mean of a categorical variable is not meaningful, so an all-NaN
  -- placeholder of the extended shape is returned. The placeholder is built
  -- as 0 / 0; the previous ones-over-zero form produced infinities instead.
  mean d =
    F.divScalar (0.0 :: Float) (D.zeros (extendedShape d []) D.float_opts)

  -- Same all-NaN placeholder as 'mean'.
  variance d =
    F.divScalar (0.0 :: Float) (D.zeros (extendedShape d []) D.float_opts)

  -- Draw @product sampleShape@ indices per row of the flattened
  -- probabilities, then reshape to the extended shape.
  -- The trailing 'True' is presumably "sample with replacement" -- confirm
  -- against 'D.multinomialIO'.
  sample d sampleShape = do
    let probs2d = D.reshape [-1, numEvents d] (probs d)
    samples2d <-
      F.transpose2D <$> D.multinomialIO probs2d (product sampleShape) True
    pure $ D.reshape (extendedShape d sampleShape) samples2d

  -- Look up the log-probability of each category index in @value@ by
  -- gathering from 'logits' along the last dimension.
  -- NOTE(review): selecting index 0 on the freshly added trailing axis looks
  -- like it undoes the 'I.unsqueeze'; confirm the index tensor has the rank
  -- 'I.gather' expects when 'logits' carries batch dimensions.
  logProb d value =
    F.squeezeDim (-1) $ I.gather (logits d) (-1 :: Int) index False
    where
      index =
        D.select (-1) 0 $
          I.unsqueeze (F.toDType D.Int64 value) (-1 :: Int)

  -- Negated sum over the last dimension of @logits * probs@.
  -- NOTE(review): a zero probability paired with a -inf logit may produce
  -- NaN in the product; confirm whether the logits should be clamped first.
  entropy d =
    F.mulScalar
      (-1.0 :: Float)
      (F.sumDim (F.Dim (-1)) F.RemoveDim (D.dtype pLogP) pLogP)
    where
      pLogP = logits d `F.mul` probs d

  -- Every category index @0 .. numEvents d - 1@, laid out along a new
  -- leading dimension with one singleton axis per batch dimension. With
  -- @doExpand@ the singleton axes are expanded to the batch shape.
  -- The previous version hard-coded the two values 0.0 and 1.0, which is
  -- only right for a two-outcome distribution.
  enumerateSupport d doExpand
    | doExpand = F.expand values False ([-1] <> batchShape d)
    | otherwise = values
    where
      values =
        D.reshape ([-1] <> replicate (length (batchShape d)) 1) $
          D.asTensor [0 .. numEvents d - 1]
-- | Number of categories: the size of the last dimension of 'probs'.
numEvents :: Categorical -> Int
numEvents = D.size (-1) . probs
-- | Build a 'Categorical' from a tensor of probabilities.
--
-- The 'logits' field is derived with 'probsToLogits'. The 'False' flag is
-- presumably the "not binary" switch of that helper -- confirm against its
-- definition.
fromProbs :: D.Tensor -> Categorical
fromProbs ps = Categorical ps (probsToLogits False ps)
-- | Build a 'Categorical' from a tensor of logits.
--
-- The 'probs' field is derived with 'logitsToProbs'. The 'False' flag is
-- presumably the "not binary" switch of that helper -- confirm against its
-- definition.
fromLogits :: D.Tensor -> Categorical
fromLogits logits' = Categorical (logitsToProbs False logits') logits'