{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Torch.GraduallyTyped.Tensor.Other where

import Control.Monad.Catch (MonadThrow)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..))
import Torch.GraduallyTyped.Scalar (Scalar)
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))
import qualified Torch.Internal.Cast as ATen (cast2, cast3)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen

-- | Apply libtorch's @triu@ to a tensor.
--
-- Thin wrapper that forwards @input@ and @diagonal@ (in that order) to the
-- ATen binding 'ATen.triu_tl'. The gradient, layout, device, data type and
-- shape indices of the result are exactly those of the input.
--
-- NOTE(review): presumably this keeps the upper-triangular part of the input
-- and @diagonal@ has the meaning it has for @torch.triu@ (0 selects the main
-- diagonal) -- confirm against the libtorch documentation.
triu ::
  forall gradient layout device dataType shape.
  -- | diagonal
  Int ->
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  Tensor gradient layout device dataType shape
triu diagonal input =
  -- NOTE(review): 'unsafePerformIO' is only sound if the foreign call does
  -- not mutate @input@ and has no other observable effects -- confirm that
  -- 'ATen.triu_tl' is the out-of-place variant.
  unsafePerformIO $ ATen.cast2 ATen.triu_tl input diagonal

-- | Apply libtorch's @tril@ to a tensor.
--
-- Thin wrapper that forwards @input@ and @diagonal@ (in that order) to the
-- ATen binding 'ATen.tril_tl'. The gradient, layout, device, data type and
-- shape indices of the result are exactly those of the input.
--
-- NOTE(review): presumably this keeps the lower-triangular part of the input
-- and @diagonal@ has the meaning it has for @torch.tril@ (0 selects the main
-- diagonal) -- confirm against the libtorch documentation.
tril ::
  forall gradient layout device dataType shape.
  -- | diagonal
  Int ->
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  Tensor gradient layout device dataType shape
tril diagonal input =
  -- NOTE(review): 'unsafePerformIO' is only sound if the foreign call does
  -- not mutate @input@ and has no other observable effects -- confirm that
  -- 'ATen.tril_tl' is the out-of-place variant.
  unsafePerformIO $ ATen.cast2 ATen.tril_tl input diagonal

-- | Masked fill.
--
-- Forwards @input@, @mask@ and @value@ (in that order) to the ATen binding
-- 'ATen.masked_fill_tts'. Note that the Haskell argument order differs from
-- the order used for the foreign call: here it is mask, value, input.
--
-- What the signature states:
--
-- * the gradient index of the mask is unified with @WithoutGradient@ and its
--   data type index with @Bool@; both unifications are guarded by 'Catch';
--
-- * the result shape is @BroadcastShapesF shape shape'@, also guarded by
--   'Catch';
--
-- * the result keeps the gradient and data type indices of the input, while
--   its layout is the unification of both layouts with @Dense@ and its device
--   is the unification of both devices.
--
-- The foreign call is run through 'unsafeThrowableIO', so the result is
-- delivered in an arbitrary 'MonadThrow' monad.
--
-- NOTE(review): presumably elements of the input at positions where the mask
-- is true are replaced by @value@, as for @torch.masked_fill@, and libtorch
-- errors are rethrown in @m@ -- confirm against the libtorch documentation
-- and the definition of 'unsafeThrowableIO'.
maskedFill ::
  forall gradient layout device dataType shape value gradient' layout' device' dataType' shape' shape'' m.
  ( Scalar value,
    MonadThrow m,
    Catch (gradient <+> 'Gradient 'WithoutGradient),
    Catch (dataType <+> 'DataType 'Bool),
    shape'' ~ BroadcastShapesF shape shape',
    Catch shape''
  ) =>
  -- | mask
  Tensor gradient layout device dataType shape ->
  -- | value
  value ->
  -- | input
  Tensor gradient' layout' device' dataType' shape' ->
  -- | output
  m
    ( Tensor
        gradient'
        (layout <+> layout' <+> 'Layout 'Dense)
        (device <+> device')
        dataType'
        shape''
    )
maskedFill mask value input =
  -- The ATen binding expects the tensor to be filled first, then the mask,
  -- then the fill value.
  unsafeThrowableIO $ ATen.cast3 ATen.masked_fill_tts input mask value