{-# LANGUAGE DataKinds #-}

module Torch.Distributions.Constraints
  ( Constraint,
    dependent,
    boolean,
    integerInterval,
    integerLessThan,
    integerGreaterThan,
    integerLessThanEq,
    integerGreaterThanEq,
    real,
    greaterThan,
    greaterThanEq,
    lessThan,
    lessThanEq,
    interval,
    halfOpenInterval,
    simplex,
    nonNegativeInteger,
    positiveInteger,
    positive,
    unitInterval,
  )
where

import qualified Torch.Functional as F
import qualified Torch.Functional.Internal as I
import Torch.Scalar
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D

-- | A constraint is a predicate on tensors: applied to an input tensor it
-- returns a tensor of truth values (per entry, or per slice for reducing
-- constraints such as 'simplex') telling which parts satisfy it.
type Constraint = D.Tensor -> D.Tensor

-- | Marker constraint that cannot be checked in isolation. Applying it to any
-- tensor raises an 'error' rather than producing a result.
dependent :: Constraint
dependent = \_ -> error "Cannot determine validity of dependent constraint"

-- | Elementwise check that every entry equals either 0 or 1.
boolean :: Constraint
boolean tensor = I.logical_or (matches D.zerosLike) (matches D.onesLike)
  where
    -- Compare the input against a same-shaped tensor produced by @mkLike@.
    matches mkLike = F.eq tensor (mkLike tensor)

-- | Elementwise check that every entry lies in the closed range
-- @[lowerBound, upperBound]@ (both ends inclusive).
--
-- NOTE(review): only the bounds are compared; entries are not verified to be
-- whole numbers.
integerInterval :: Int -> Int -> Constraint
integerInterval lowerBound upperBound tensor =
  I.logical_and atLeastLower atMostUpper
  where
    atLeastLower = integerGreaterThanEq lowerBound tensor
    atMostUpper = integerLessThanEq upperBound tensor

-- | Elementwise check that every entry is strictly less than @upperBound@.
--
-- NOTE(review): entries are not verified to be whole numbers.
integerLessThan :: Int -> Constraint
integerLessThan upperBound tensor = F.lt tensor threshold
  where
    -- The bound is materialized as a tensor shaped like the input.
    threshold = fullLike' upperBound tensor

-- | Elementwise check that every entry is strictly greater than @lowerBound@.
--
-- NOTE(review): entries are not verified to be whole numbers.
integerGreaterThan :: Int -> Constraint
integerGreaterThan lowerBound tensor = F.gt tensor threshold
  where
    -- The bound is materialized as a tensor shaped like the input.
    threshold = fullLike' lowerBound tensor

-- | Elementwise check that every entry is at most @upperBound@ (inclusive).
--
-- NOTE(review): entries are not verified to be whole numbers.
integerLessThanEq :: Int -> Constraint
integerLessThanEq upperBound tensor = F.le tensor threshold
  where
    -- The bound is materialized as a tensor shaped like the input.
    threshold = fullLike' upperBound tensor

-- | Elementwise check that every entry is at least @lowerBound@ (inclusive).
--
-- NOTE(review): entries are not verified to be whole numbers.
integerGreaterThanEq :: Int -> Constraint
integerGreaterThanEq lowerBound tensor = F.ge tensor threshold
  where
    -- The bound is materialized as a tensor shaped like the input.
    threshold = fullLike' lowerBound tensor

-- | Elementwise finiteness check, delegating directly to @isfinite@.
--
-- NOTE(review): PyTorch's @constraints.real@ only rejects NaN and accepts
-- infinities; confirm that the stricter finiteness test here is intended.
real :: Constraint
real tensor = I.isfinite tensor

-- | Elementwise check that every entry is strictly greater than @lowerBound@.
greaterThan :: Float -> Constraint
greaterThan lowerBound tensor = F.gt tensor threshold
  where
    -- The bound is materialized as a tensor shaped like the input.
    threshold = fullLike' lowerBound tensor

-- | Elementwise check that every entry is at least @lowerBound@ (inclusive).
greaterThanEq :: Float -> Constraint
greaterThanEq lowerBound tensor = F.ge tensor threshold
  where
    -- The bound is materialized as a tensor shaped like the input.
    threshold = fullLike' lowerBound tensor

-- | Elementwise check that every entry is strictly less than @upperBound@.
lessThan :: Float -> Constraint
lessThan upperBound tensor = F.lt tensor threshold
  where
    -- The bound is materialized as a tensor shaped like the input.
    threshold = fullLike' upperBound tensor

-- | Elementwise check that every entry is at most @upperBound@ (inclusive).
lessThanEq :: Float -> Constraint
lessThanEq upperBound tensor = F.le tensor threshold
  where
    -- The bound is materialized as a tensor shaped like the input.
    threshold = fullLike' upperBound tensor

-- | Elementwise check that every entry lies in the closed interval
-- @[lowerBound, upperBound]@.
interval :: Float -> Float -> Constraint
interval lowerBound upperBound tensor =
  I.logical_and atLeastLower atMostUpper
  where
    atLeastLower = greaterThanEq lowerBound tensor
    atMostUpper = lessThanEq upperBound tensor

-- | Elementwise check that every entry lies in the half-open interval
-- @[lowerBound, upperBound)@: the lower end is inclusive, the upper exclusive.
halfOpenInterval :: Float -> Float -> Constraint
halfOpenInterval lowerBound upperBound tensor =
  I.logical_and atLeastLower belowUpper
  where
    atLeastLower = greaterThanEq lowerBound tensor
    belowUpper = lessThan upperBound tensor

-- | For each slice along the last dimension, checks that all entries are
-- non-negative and that the slice sums to 1 within an absolute tolerance of
-- @1e-6@. The sums are taken with 'F.RemoveDim', so the last dimension is
-- reduced away.
simplex :: Constraint
simplex tensor = I.logical_and allNonNegative sumsToOne
  where
    lastDim = F.Dim (-1)
    allNonNegative = F.allDim lastDim False (greaterThanEq 0.0 tensor)
    -- Sum in the input's own dtype, dropping the reduced dimension.
    totals = F.sumDim lastDim F.RemoveDim (D.dtype tensor) tensor
    sumsToOne = lessThan 1e-6 (F.abs (F.sub totals (D.onesLike totals)))

-- TODO: lowerTriangular
-- TODO: lowerCholesky
-- TODO: positiveDefinite
-- TODO: realVector
-- TODO: cat
-- TODO: stack

-- | Every entry must be at least 0; equivalent to @integerGreaterThanEq 0@.
nonNegativeInteger :: Constraint
nonNegativeInteger tensor = integerGreaterThanEq 0 tensor

-- | Every entry must be at least 1; equivalent to @integerGreaterThanEq 1@.
positiveInteger :: Constraint
positiveInteger tensor = integerGreaterThanEq 1 tensor

-- | Every entry must be strictly greater than 0; equivalent to
-- @greaterThan 0.0@.
positive :: Constraint
positive tensor = greaterThan 0.0 tensor

-- | Every entry must lie in the closed interval @[0, 1]@; equivalent to
-- @interval 0.0 1.0@.
unitInterval :: Constraint
unitInterval tensor = interval 0.0 1.0 tensor

-- | Builds a tensor shaped like the template tensor, with every entry equal to
-- @fill@. It is computed as @fill * onesLike template@, so the result dtype
-- follows the library's tensor-by-scalar multiplication rules.
fullLike' :: (Scalar a) => a -> D.Tensor -> D.Tensor
fullLike' fill = F.mulScalar fill . D.onesLike