{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Torch.GraduallyTyped.Tensor.Other where
import Control.Monad.Catch (MonadThrow)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..))
import Torch.GraduallyTyped.Scalar (Scalar)
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>))
import qualified Torch.Internal.Cast as ATen (cast2, cast3)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
-- | Upper-triangular part of a tensor.
--
-- Thin wrapper around the ATen binding @triu_tl@
-- (@ForeignPtr Tensor -> Int64 -> IO (ForeignPtr Tensor)@). It is marshalled
-- with 'ATen.cast2', which converts the Haskell tensor and the 'Int' offset to
-- their C-side representations and converts the result back.
--
-- The gradient, layout, device, data type and shape indices of the result are
-- exactly those of the input.
--
-- NOTE(review): the semantics are presumably those of PyTorch's @torch.triu@:
-- elements below the @diagonal@-th diagonal are zeroed. Confirm against ATen.
triu ::
  forall gradient layout device dataType shape.
  -- | diagonal offset
  Int ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  Tensor gradient layout device dataType shape
-- 'unsafePerformIO' exposes the FFI call as a pure function. The result
-- depends only on the arguments.
-- NOTE(review): this assumes the ATen call does not mutate @input@ — confirm.
triu diagonal input = unsafePerformIO $ ATen.cast2 ATen.triu_tl input diagonal
-- | Lower-triangular part of a tensor.
--
-- Thin wrapper around the ATen binding @tril_tl@
-- (@ForeignPtr Tensor -> Int64 -> IO (ForeignPtr Tensor)@). It is marshalled
-- with 'ATen.cast2', which converts the Haskell tensor and the 'Int' offset to
-- their C-side representations and converts the result back.
--
-- The gradient, layout, device, data type and shape indices of the result are
-- exactly those of the input.
--
-- NOTE(review): the semantics are presumably those of PyTorch's @torch.tril@:
-- elements above the @diagonal@-th diagonal are zeroed. Confirm against ATen.
tril ::
  forall gradient layout device dataType shape.
  -- | diagonal offset
  Int ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  Tensor gradient layout device dataType shape
-- 'unsafePerformIO' exposes the FFI call as a pure function. The result
-- depends only on the arguments.
-- NOTE(review): this assumes the ATen call does not mutate @input@ — confirm.
tril diagonal input = unsafePerformIO $ ATen.cast2 ATen.tril_tl input diagonal
-- | Fill the positions of an input tensor selected by a mask with a scalar.
--
-- Thin wrapper around the ATen binding @masked_fill_tts@
-- (@ForeignPtr Tensor -> ForeignPtr Tensor -> ForeignPtr Scalar
-- -> IO (ForeignPtr Tensor)@). It is called with the arguments in the order
-- @input@, @mask@, @value@.
--
-- Type-level requirements on the arguments:
--
--   * the mask's gradient index must unify with @'Gradient 'WithoutGradient@;
--   * the mask's data type must unify with @'DataType 'Bool@;
--   * the mask and input shapes must broadcast ('BroadcastShapesF').
--
-- The result keeps the input's gradient index and data type. Its layout is
-- the unification of both layouts with @'Layout 'Dense@, its device is the
-- unification of both devices, and its shape is the broadcast shape.
--
-- Runtime errors from the ATen call are surfaced in @m@ through
-- 'unsafeThrowableIO' (which requires 'MonadThrow').
--
-- NOTE(review): the semantics are presumably those of PyTorch's
-- @Tensor.masked_fill@, which writes @value@ where @mask@ is true. Confirm.
maskedFill ::
  forall gradient layout device dataType shape value gradient' layout' device' dataType' shape' shape'' m.
  ( Scalar value,
    MonadThrow m,
    Catch (gradient <+> 'Gradient 'WithoutGradient),
    Catch (dataType <+> 'DataType 'Bool),
    shape'' ~ BroadcastShapesF shape shape',
    Catch shape''
  ) =>
  -- | mask
  Tensor gradient layout device dataType shape ->
  -- | fill value
  value ->
  -- | input tensor
  Tensor gradient' layout' device' dataType' shape' ->
  -- | output tensor
  m
    ( Tensor
        gradient'
        (layout <+> layout' <+> 'Layout 'Dense)
        (device <+> device')
        dataType'
        shape''
    )
-- The ATen binding takes the input tensor first, then the mask, then the value.
maskedFill mask value input =
  unsafeThrowableIO $ ATen.cast3 ATen.masked_fill_tts input mask value