{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
module Torch.Typed.Functional where
import Control.Arrow ((&&&))
import Data.Finite
import qualified Data.Int as I
import Data.Kind
( Constraint,
Type,
)
import Data.Maybe
import Data.Proxy
import Data.Reflection
import Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as V
import Foreign.ForeignPtr
import GHC.Generics (Generic)
import GHC.Natural (Natural)
import GHC.TypeLits
import GHC.TypeLits.Extra
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.Functional
( Reduction (..),
Tri (..),
isUpper,
kOne,
)
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen.Managed
import qualified Torch.Internal.Managed.Native as ATen.Managed
import qualified Torch.Internal.Managed.Type.Scalar as ATen.Managed
import qualified Torch.Internal.Managed.Type.Tensor as ATen.Managed
import qualified Torch.Internal.Managed.Type.Tuple as ATen.Managed
import qualified Torch.Internal.Type as ATen
import qualified Torch.Scalar as D
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import qualified Torch.TensorOptions as D
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Tensor
import Prelude hiding
( abs,
acos,
acosh,
all,
any,
asin,
asinh,
atan,
atanh,
ceil,
cos,
cosh,
exp,
floor,
isNaN,
log,
max,
min,
round,
sin,
sinh,
tan,
tanh,
)
-- | Element-wise bitwise NOT of a boolean tensor (ATen @bitwise_not@).
--
-- The result lives on the same device and has the same shape as the input.
bitwiseNot ::
  forall device shape.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape
bitwiseNot input = unsafePerformIO $ ATen.cast1 ATen.Managed.bitwise_not_t input
-- | Element-wise logical NOT of a boolean tensor (ATen @logical_not@).
logicalNot ::
  forall device shape.
  -- | input
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape
logicalNot input = unsafePerformIO $ ATen.cast1 ATen.Managed.logical_not_t input
-- | Element-wise logical XOR of two boolean tensors of identical shape
-- (ATen @logical_xor@).
logicalXor ::
  forall device shape.
  -- | self
  Tensor device 'D.Bool shape ->
  -- | other
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape
logicalXor self other =
  unsafePerformIO $ ATen.cast2 ATen.Managed.logical_xor_tt self other
-- | Element-wise logical AND of two boolean tensors of identical shape
-- (ATen @logical_and@).
logicalAnd ::
  forall device shape.
  -- | self
  Tensor device 'D.Bool shape ->
  -- | other
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape
logicalAnd self other =
  unsafePerformIO $ ATen.cast2 ATen.Managed.logical_and_tt self other
-- | Element-wise logical OR of two boolean tensors of identical shape
-- (ATen @logical_or@).
logicalOr ::
  forall device shape.
  -- | self
  Tensor device 'D.Bool shape ->
  -- | other
  Tensor device 'D.Bool shape ->
  -- | output
  Tensor device 'D.Bool shape
logicalOr self other =
  unsafePerformIO $ ATen.cast2 ATen.Managed.logical_or_tt self other
-- | Result dtype of a sum reduction: 'D.Bool' and all integral dtypes are
-- accumulated as 'D.Int64'; 'D.Half', 'D.Float' and 'D.Double' keep their
-- own dtype.
type family SumDType (dtype :: D.DType) :: D.DType where
  SumDType D.Bool = D.Int64
  SumDType D.UInt8 = D.Int64
  SumDType D.Int8 = D.Int64
  SumDType D.Int16 = D.Int64
  SumDType D.Int32 = D.Int64
  SumDType D.Int64 = D.Int64
  SumDType D.Half = D.Half
  SumDType D.Float = D.Float
  SumDType D.Double = D.Double
-- | Device/dtype validity for sum reductions.
--
-- * CPU with device index 0: any dtype except 'D.Half'.
-- * CUDA (any index): every dtype is accepted.
-- * Anything else (including CPU with a non-zero index) reduces to
--   'UnsupportedDTypeForDevice' (presumably a custom type error).
type family SumDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SumDTypeIsValid '( 'D.CPU, 0) dtype = DTypeIsNotHalf '( 'D.CPU, 0) dtype
  SumDTypeIsValid '( 'D.CUDA, _) dtype = ()
  SumDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype
-- | Sum of all elements, returned as a rank-0 tensor (ATen @sum@).
--
-- The result dtype is @'SumDType' dtype@, so boolean and integral inputs
-- are accumulated as 'D.Int64'.
sumAll ::
  forall shape dtype' dtype device.
  ( SumDTypeIsValid device dtype,
    dtype' ~ SumDType dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype' '[]
sumAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.sum_t input
-- | Sum along dimension @d@ (ATen @sum@ with a dim list).
--
-- The result shape is @'DropValue' shape d@ and the result dtype is
-- @'SumDType' dtype@.
sumDim ::
  forall d shape shape' dtype dtype' device.
  ( KnownNat d,
    shape' ~ DropValue shape d,
    SumDTypeIsValid device dtype,
    dtype' ~ SumDType dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype' shape'
sumDim input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.sum_tl input (natValI @d)
-- | Element-wise absolute value (ATen @abs@).
abs ::
  forall shape dtype device.
  (StandardDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
abs input = unsafePerformIO $ ATen.cast1 ATen.Managed.abs_t input
-- | Element-wise ceiling of a floating-point tensor (ATen @ceil@).
ceil ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
ceil input = unsafePerformIO $ ATen.cast1 ATen.Managed.ceil_t input
-- | Element-wise floor of a floating-point tensor (ATen @floor@).
floor ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
floor input = unsafePerformIO $ ATen.cast1 ATen.Managed.floor_t input
-- | Device/dtype validity for 'min' / 'max' reductions.
--
-- * CPU with device index 0: any dtype except 'D.Half'.
-- * CUDA (any index): every dtype is accepted.
-- * Anything else reduces to 'UnsupportedDTypeForDevice'
--   (presumably a custom type error).
type family MinMaxDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MinMaxDTypeIsValid '( 'D.CPU, 0) dtype = DTypeIsNotHalf '( 'D.CPU, 0) dtype
  MinMaxDTypeIsValid '( 'D.CUDA, _) dtype = ()
  MinMaxDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype
-- | Minimum over all elements, returned as a rank-0 tensor (ATen @min@).
--
-- Requires every dimension of @shape@ to be positive ('AllDimsPositive'),
-- so the reduction never runs over an empty tensor.
min ::
  forall shape dtype device.
  ( MinMaxDTypeIsValid device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
min input = unsafePerformIO $ ATen.cast1 ATen.Managed.min_t input
-- | Maximum over all elements, returned as a rank-0 tensor (ATen @max@).
--
-- Requires every dimension of @shape@ to be positive ('AllDimsPositive'),
-- so the reduction never runs over an empty tensor.
max ::
  forall shape dtype device.
  ( MinMaxDTypeIsValid device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
max input = unsafePerformIO $ ATen.cast1 ATen.Managed.max_t input
-- | Dtype validity for mean reductions on any device: the dtype must be
-- floating point and must not be 'D.Half'.
type family MeanDTypeValidation (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MeanDTypeValidation '(deviceType, deviceIndex) dtype =
    ( DTypeIsFloatingPoint '(deviceType, deviceIndex) dtype,
      DTypeIsNotHalf '(deviceType, deviceIndex) dtype
    )
-- | Arithmetic mean over all elements, returned as a rank-0 tensor
-- (ATen @mean@).
--
-- Requires a non-'D.Half' floating-point dtype and every dimension of
-- @shape@ to be positive.
meanAll ::
  forall shape dtype device.
  ( MeanDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
meanAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.mean_t input
-- | Same as 'meanAll' but without the 'AllDimsPositive' constraint.
--
-- The shape is not checked at the type level. The caller is responsible
-- for making sure the input is valid for a mean (e.g. non-empty).
unsafeMeanAll ::
  forall shape dtype device.
  MeanDTypeValidation device dtype =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
unsafeMeanAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.mean_t input
-- | Mean along dimension @dim@ (ATen @mean@ with a dim list).
--
-- The result shape is @'DropValue' shape dim@.
meanDim ::
  forall dim shape' shape dtype device.
  ( KnownNat dim,
    shape' ~ DropValue shape dim,
    MeanDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
meanDim input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.mean_tl input (natValI @dim)
-- | Mean along the named dimension @dim@ of a 'NamedTensor'.
--
-- The positional index of @dim@ is resolved at the type level via
-- 'FindDim'. The result shape is @'DropNamedValue' shape dim@.
meanNamedDim ::
  forall dim shape' shape dtype device.
  ( KnownNat (FindDim dim shape),
    shape' ~ DropNamedValue shape dim,
    MeanDTypeValidation device dtype
  ) =>
  -- | input
  NamedTensor device dtype shape ->
  -- | output
  NamedTensor device dtype shape'
meanNamedDim input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.mean_tl input dimIndex
  where
    -- Position of the named dimension within @shape@.
    dimIndex :: Int
    dimIndex = natValI @(FindDim dim shape)
-- | Mean along dimension @dim@, either keeping or dropping that dimension
-- (ATen @mean@ with @keepdim@).
--
-- The result shape is
-- @'ConditionalDropDimension' shape dim keepOrDropDim@.
mean ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    MeanDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
mean input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.mean_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | Median over all elements, returned as a rank-0 tensor (ATen @median@).
--
-- Requires every dimension of @shape@ to be positive.
medianAll ::
  forall shape dtype device.
  ( StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype '[]
medianAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.median_t input
-- | Median along dimension @dim@ (ATen @median@ with a dim).
--
-- Returns a pair: the median values and an 'D.Int64' tensor of their
-- indices along @dim@. Both have shape @'DropValue' shape dim@.
medianDim ::
  forall dim shape' shape dtype device.
  ( KnownNat dim,
    shape' ~ DropValue shape dim,
    StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | (values, indices)
  ( Tensor device dtype shape',
    Tensor device 'D.Int64 shape'
  )
medianDim input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.median_tl input (natValI @dim)
-- | Median along dimension @dim@, keeping or dropping that dimension
-- (ATen @median@ with @keepdim@).
--
-- Returns a pair: the median values and an 'D.Int64' tensor of their
-- indices. Both have shape
-- @'ConditionalDropDimension' shape dim keepOrDropDim@.
median ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | (values, indices)
  (Tensor device dtype shape', Tensor device 'D.Int64 shape')
median input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.median_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | Mode (most frequent value) along dimension @dim@, keeping or dropping
-- that dimension (ATen @mode@).
--
-- Returns a pair: the mode values and an 'D.Int64' tensor of their
-- indices. Both have shape
-- @'ConditionalDropDimension' shape dim keepOrDropDim@.
mode ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype,
    AllDimsPositive shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | (values, indices)
  (Tensor device dtype shape', Tensor device 'D.Int64 shape')
mode input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.mode_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | Adds a scalar to every element: @input + a@ (ATen @add@ tensor-scalar).
addScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
addScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.add_ts input a
-- | Subtracts a scalar from every element: @input - a@
-- (ATen @sub@ tensor-scalar).
subScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
subScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.sub_ts input a
-- | Multiplies every element by a scalar: @input * a@
-- (ATen @mul@ tensor-scalar).
mulScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
mulScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.mul_ts input a
-- | Divides every element by a scalar: @input / a@
-- (ATen @div@ tensor-scalar).
divScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | scalar
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
divScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.div_ts input a
-- | Raises every element to a scalar exponent: @input ** a@
-- (ATen @pow@ tensor-scalar).
--
-- The exponent is the first argument.
powScalar ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | exponent
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
powScalar a input = unsafePerformIO $ ATen.cast2 ATen.Managed.pow_ts input a
-- | Element-wise Gauss error function (ATen @erf@).
erf ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
erf input = unsafePerformIO $ ATen.cast1 ATen.Managed.erf_t input
-- | Element-wise natural exponential (ATen @exp@).
exp ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
exp input = unsafePerformIO $ ATen.cast1 ATen.Managed.exp_t input
-- | Element-wise @log (1 + x)@ (ATen @log1p@).
log1p ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log1p input = unsafePerformIO $ ATen.cast1 ATen.Managed.log1p_t input
-- | Element-wise base-2 logarithm (ATen @log2@).
log2 ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log2 input = unsafePerformIO $ ATen.cast1 ATen.Managed.log2_t input
-- | Element-wise base-10 logarithm (ATen @log10@).
log10 ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log10 input = unsafePerformIO $ ATen.cast1 ATen.Managed.log10_t input
-- | Element-wise power with broadcasting: @input ** exponent@
-- (ATen @pow@ tensor-tensor).
--
-- Note the argument order: the exponent tensor comes first, mirroring
-- 'powScalar'. The result shape is @'Broadcast' shape shape'@.
pow ::
  forall shape'' shape shape' dtype device.
  ( BasicArithmeticDTypeIsValid device dtype,
    shape'' ~ Broadcast shape shape'
  ) =>
  -- | exponent
  Tensor device dtype shape ->
  -- | input
  Tensor device dtype shape' ->
  -- | output
  Tensor device dtype shape''
pow exponent input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.pow_tt input exponent
-- | Element-wise rectified linear unit (ATen @relu@).
--
-- Polymorphic over any unnamed tensor type @t@ ('IsUnnamed'). The value is
-- wrapped in 'Wrap' for the FFI cast and unwrapped afterwards.
relu ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
relu input =
  unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.relu_t (Wrap input)
-- | Element-wise scaled exponential linear unit (ATen @selu@).
selu ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
selu input = unsafePerformIO $ ATen.cast1 ATen.Managed.selu_t input
-- | Mish activation: @input * tanh (softplus input)@.
--
-- 'softplus' is called with the scalar arguments @1@ and @20@.
-- NOTE(review): these are presumably @beta@ and @threshold@, in the order
-- 'softplus' takes them; confirm against its definition.
mish ::
  forall shape dtype device.
  ( StandardFloatingPointDTypeValidation device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    shape ~ Broadcast shape shape
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
-- Written pointfully. This is the same as the earlier reader-monad form
-- @mul =<< tanh . softplus 1 20@, which expands to this expression.
mish input = mul (tanh (softplus (1 :: Float) 20 input)) input
-- | Element-wise logistic sigmoid (ATen @sigmoid@).
sigmoid ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
sigmoid input = unsafePerformIO $ ATen.cast1 ATen.Managed.sigmoid_t input
-- | Element-wise sine (ATen @sin@), for any unnamed tensor type @t@.
sin ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
sin input =
  unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.sin_t (Wrap input)
-- | Element-wise hyperbolic sine (ATen @sinh@), for any unnamed tensor
-- type @t@.
sinh ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
sinh input =
  unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.sinh_t (Wrap input)
-- | Element-wise cosine (ATen @cos@), for any unnamed tensor type @t@.
cos ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
cos input =
  unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.cos_t (Wrap input)
-- | Element-wise square root (ATen @sqrt@), for any unnamed tensor type @t@.
--
-- NOTE(review): Prelude's @sqrt@ is not in this module's @hiding@ list, so
-- unqualified uses of @sqrt@ elsewhere in the module would be ambiguous.
sqrt ::
  forall shape dtype device t.
  ( StandardFloatingPointDTypeValidation device dtype,
    IsUnnamed t device dtype shape
  ) =>
  -- | input
  t ->
  -- | output
  t
sqrt input =
  unWrap $ unsafePerformIO $ ATen.cast1 ATen.Managed.sqrt_t (Wrap input)
-- | Element-wise hyperbolic tangent (ATen @tanh@).
tanh ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
tanh input = unsafePerformIO $ ATen.cast1 ATen.Managed.tanh_t input
-- | Output shape of a loss under a given reduction.
--
-- 'ReduceNone' keeps the input shape unchanged. Every other reduction yields a
-- scalar tensor, i.e. the empty shape @'[]@.
type family ConditionalReduction (shape :: [Nat]) (reduction :: Reduction) :: [Nat] where
  ConditionalReduction inputShape 'ReduceNone = inputShape
  ConditionalReduction _ _ = '[]
-- | Reflects a promoted 'Reduction' to the integer code that the loss
-- functions in this module pass to libtorch as their @reduction@ argument.
--
-- The class parameter is kind-annotated so that only 'Reduction' values are
-- accepted, consistent with 'KnownReducedSVD' and 'KnownTri'.
class KnownReduction (reduction :: Reduction) where
  reductionVal :: Int

-- | No reduction: code 0.
instance KnownReduction 'ReduceNone where
  reductionVal = 0

-- | Mean reduction: code 1.
instance KnownReduction 'ReduceMean where
  reductionVal = 1

-- | Sum reduction: code 2.
instance KnownReduction 'ReduceSum where
  reductionVal = 2
-- | Binary cross-entropy loss between @prediction@ and @target@, scaled by
-- @weight@.
--
-- The Haskell argument order is @weight prediction target@. They are forwarded
-- to ATen as @(prediction, target, weight, reduction)@. The result shape is
-- given by 'ConditionalReduction': the input shape for 'ReduceNone', a scalar
-- otherwise.
binaryCrossEntropy ::
  forall (reduction :: Reduction) shape shape' dtype device.
  ( KnownReduction reduction,
    shape' ~ ConditionalReduction shape reduction,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape ->
  Tensor device dtype shape ->
  Tensor device dtype shape'
binaryCrossEntropy weight prediction target =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $
    ATen.cast4
      ATen.Managed.binary_cross_entropy_tttl
      prediction
      target
      weight
      (reductionVal @reduction)
-- | Mean-squared-error loss between @prediction@ and @target@.
--
-- The @reduction@ type argument selects the integer reduction code passed to
-- ATen. The result shape follows 'ConditionalReduction': the input shape for
-- 'ReduceNone', a scalar otherwise.
mseLoss ::
  forall (reduction :: Reduction) shape shape' dtype device.
  ( KnownReduction reduction,
    shape' ~ ConditionalReduction shape reduction,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape ->
  Tensor device dtype shape'
mseLoss prediction target =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.mse_loss_ttl
      prediction
      target
      (reductionVal @reduction)
-- | Softmax along dimension @dim@.
--
-- @dim@ is checked against the shape at compile time by
-- 'DimOutOfBoundCheck'. The tensor's own dtype is forwarded as the ATen
-- scalar-type argument (presumably the output dtype), so the shape and dtype
-- are preserved.
softmax ::
  forall dim shape dtype device.
  ( KnownNat dim,
    DimOutOfBoundCheck shape dim,
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
softmax input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $
    ATen.cast3 ATen.Managed.softmax_tls input (natValI @dim) (dtypeVal @dtype)
-- | Log-softmax along dimension @dim@.
--
-- @dim@ is checked against the shape at compile time by
-- 'DimOutOfBoundCheck'. The tensor's own dtype is forwarded as the ATen
-- scalar-type argument (presumably the output dtype), so the shape and dtype
-- are preserved.
logSoftmax ::
  forall dim shape dtype device.
  ( KnownNat dim,
    DimOutOfBoundCheck shape dim,
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
logSoftmax input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $
    ATen.cast3 ATen.Managed.log_softmax_tls input (natValI @dim) (dtypeVal @dtype)
-- | Identity on square shapes.
--
-- Accepts a single square matrix @[n, n]@ or a batch of square matrices
-- @[b, n, n]@. Any other shape reduces to a custom type error.
type family Square (shape :: [Nat]) :: [Nat] where
  Square '[side, side] = '[side, side]
  Square '[batch, side, side] = '[batch, side, side]
  Square _ =
    TypeError (Text "This shape must be square matrix or batch + square matrix.")
-- | Shape of a vector with one entry per row of a square matrix.
--
-- Maps @[n, n]@ to @[n]@ and a batch @[b, n, n]@ to @[b, n]@. Non-square
-- shapes reduce to a custom type error.
type family VectorOfSquare (shape :: [Nat]) :: [Nat] where
  VectorOfSquare '[side, side] = '[side]
  VectorOfSquare '[batch, side, side] = '[batch, side]
  VectorOfSquare _ =
    TypeError (Text "This shape must be square matrix or batch + square matrix.")
-- | Row count of a matrix.
--
-- For a matrix @[n, m]@ the result is @n@. For a batch @[b, n, m]@ it is also
-- @n@, the per-matrix row count. The matrix need not be square. Any other
-- rank reduces to a custom type error.
type family FstSquareDim (shape :: [Nat]) :: Nat where
  FstSquareDim '[n, _] = n
  FstSquareDim '[_, n, _] = n
  FstSquareDim _ =
    TypeError (Text "Cannot get the first dimension of matrix or batch + matrix.")
-- | Shape constraint for 'inverse'.
--
-- On the CPU device @'( 'D.CPU, 0)@ any shape is allowed. On a CUDA device
-- every dimension must satisfy 'AllDimsPositive'. No equation covers other
-- devices, so the family does not reduce for them.
type family InverseShapeIsValid (device :: (D.DeviceType, Nat)) (shape :: [Nat]) :: Constraint where
  InverseShapeIsValid '( 'D.CPU, 0) _ = ()
  InverseShapeIsValid '( 'D.CUDA, _) dims = AllDimsPositive dims
-- | DType constraint for 'inverse'.
--
-- On @'( 'D.CPU, 0)@ and on any CUDA device the dtype must satisfy both
-- 'DTypeIsFloatingPoint' and 'DTypeIsNotHalf'. Every other device reduces to
-- 'UnsupportedDTypeForDevice'.
type family InverseDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  InverseDTypeIsValid '( 'D.CPU, 0) dt =
    (DTypeIsFloatingPoint '( 'D.CPU, 0) dt, DTypeIsNotHalf '( 'D.CPU, 0) dt)
  InverseDTypeIsValid '( 'D.CUDA, idx) dt =
    (DTypeIsFloatingPoint '( 'D.CUDA, idx) dt, DTypeIsNotHalf '( 'D.CUDA, idx) dt)
  InverseDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt
-- | Matrix inverse of a square matrix or of a batch of square matrices.
--
-- The input shape must be accepted by 'Square'; the result has the same
-- shape. Per-device validity of the shape and dtype is enforced by
-- 'InverseShapeIsValid' and 'InverseDTypeIsValid'.
inverse ::
  forall shape shape' dtype device.
  ( shape' ~ Square shape,
    InverseShapeIsValid device shape,
    InverseDTypeIsValid device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
inverse input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $ ATen.cast1 ATen.Managed.inverse_t input
-- | DType constraint for 'symeig' and 'symeigvalues'.
--
-- On @'( 'D.CPU, 0)@ and on any CUDA device the dtype must satisfy both
-- 'DTypeIsFloatingPoint' and 'DTypeIsNotHalf'. Every other device reduces to
-- 'UnsupportedDTypeForDevice'.
type family SymeigDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SymeigDTypeIsValid '( 'D.CPU, 0) dt =
    (DTypeIsFloatingPoint '( 'D.CPU, 0) dt, DTypeIsNotHalf '( 'D.CPU, 0) dt)
  SymeigDTypeIsValid '( 'D.CUDA, idx) dt =
    (DTypeIsFloatingPoint '( 'D.CUDA, idx) dt, DTypeIsNotHalf '( 'D.CUDA, idx) dt)
  SymeigDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt
-- | Eigen-decomposition of a symmetric matrix, or of a batch of them.
--
-- Returns a pair whose first component has shape @'VectorOfSquare' shape@
-- (presumably the eigenvalues) and whose second has shape @'Square' shape@
-- (presumably the eigenvectors). The 'Tri' argument is converted with
-- 'isUpper' and selects which triangle of the input is read.
symeig ::
  forall shape shape' shape'' dtype device.
  ( shape' ~ VectorOfSquare shape,
    shape'' ~ Square shape,
    SymeigDTypeIsValid device dtype
  ) =>
  Tri ->
  Tensor device dtype shape ->
  ( Tensor device dtype shape',
    Tensor device dtype shape''
  )
symeig upper input =
  -- The first flag is True; 'symeigvalues' passes False and discards the
  -- second result, so this flag presumably requests eigenvectors.
  -- NOTE(review): treated as pure; assumes the ATen op allocates new results.
  unsafePerformIO $
    ATen.cast3 ATen.Managed.symeig_tbb input True boolUpper
  where
    boolUpper :: Bool
    boolUpper = isUpper upper
-- | Eigenvalues only of a symmetric matrix, or of a batch of them.
--
-- Same libtorch call as 'symeig', but with the first flag set to False, and
-- only the first component of the result is returned. The 'Tri' argument is
-- converted with 'isUpper' and selects which triangle of the input is read.
symeigvalues ::
  forall shape shape' dtype device.
  ( shape' ~ VectorOfSquare shape,
    SymeigDTypeIsValid device dtype
  ) =>
  Tri ->
  Tensor device dtype shape ->
  Tensor device dtype shape'
symeigvalues upper input = fst symeig'
  where
    boolUpper :: Bool
    boolUpper = isUpper upper

    -- The second component is discarded, so its shape @shape''@ is left
    -- implicitly quantified rather than tied to the input shape.
    -- NOTE(review): treated as pure; assumes the ATen op allocates new results.
    symeig' :: (Tensor device dtype shape', Tensor device dtype shape'')
    symeig' =
      unsafePerformIO $
        ATen.cast3 ATen.Managed.symeig_tbb input False boolUpper
-- | Type-level switch selecting whether 'eig' computes eigenvectors.
data EigenVectors
  = -- | Compute eigenvectors as well as eigenvalues.
    EnableEigenVectors
  | -- | Compute eigenvalues only.
    DisableEigenVectors

-- | Reflects a promoted 'EigenVectors' flag to the 'Bool' passed to libtorch.
--
-- The class parameter is kind-annotated so that only 'EigenVectors' values
-- are accepted, consistent with 'KnownReducedSVD' and 'KnownTri'.
class KnownEigenVectors (eigenvectors :: EigenVectors) where
  enableEigenVectors :: Bool

instance KnownEigenVectors 'EnableEigenVectors where
  enableEigenVectors = True

instance KnownEigenVectors 'DisableEigenVectors where
  enableEigenVectors = False
-- | Shape of the eigenvector output of 'eig'.
--
-- With 'EnableEigenVectors' the shape is @[n, n]@. With
-- 'DisableEigenVectors' it is the placeholder shape @[0]@.
type family ConditionalEigenVectors (eigenvectors :: EigenVectors) (n :: Nat) :: [Nat] where
  ConditionalEigenVectors 'EnableEigenVectors side = '[side, side]
  ConditionalEigenVectors 'DisableEigenVectors _ = '[0]
-- | DType constraint for 'eig'.
--
-- On @'( 'D.CPU, 0)@ and on any CUDA device the dtype must satisfy both
-- 'DTypeIsFloatingPoint' and 'DTypeIsNotHalf'. Every other device reduces to
-- 'UnsupportedDTypeForDevice'.
type family EigDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  EigDTypeIsValid '( 'D.CPU, 0) dt =
    (DTypeIsFloatingPoint '( 'D.CPU, 0) dt, DTypeIsNotHalf '( 'D.CPU, 0) dt)
  EigDTypeIsValid '( 'D.CUDA, idx) dt =
    (DTypeIsFloatingPoint '( 'D.CUDA, idx) dt, DTypeIsNotHalf '( 'D.CUDA, idx) dt)
  EigDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt
-- | Eigen-decomposition of a general square matrix @[n, n]@.
--
-- The first result has shape @[n, 2]@ (presumably one row per eigenvalue,
-- with real and imaginary parts; verify against libtorch). The second result
-- has shape 'ConditionalEigenVectors': @[n, n]@ when eigenvectors are
-- enabled, @[0]@ otherwise. The @eigenvectors@ type argument is reflected
-- through 'enableEigenVectors'.
eig ::
  forall eigenvectors n shape dtype device.
  ( KnownNat n,
    KnownEigenVectors eigenvectors,
    shape ~ ConditionalEigenVectors eigenvectors n,
    EigDTypeIsValid device dtype
  ) =>
  Tensor device dtype '[n, n] ->
  ( Tensor device dtype '[n, 2],
    Tensor device dtype shape
  )
eig input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates new results.
  unsafePerformIO $
    ATen.cast2 ATen.Managed.eig_tb input (enableEigenVectors @eigenvectors)
-- | Shapes @(U, S, V)@ produced by 'svd'.
--
-- Applies to a matrix @[m, n]@ or a batch @[b, m, n]@.
--
-- * 'ThinSVD' keeps @Min m n@ singular vectors. An input with zero rows is
--   special-cased, and its V factor is @[n, n]@.
-- * 'FullSVD' returns square U (@[m, m]@) and V (@[n, n]@).
--
-- Any other rank is a type error.
type family SVDShapes (shape :: [Nat]) (reduced :: ReducedSVD) :: ([Nat], [Nat], [Nat]) where
  SVDShapes '[0, cols] 'ThinSVD =
    '( '[0, 0], '[0], '[cols, cols])
  SVDShapes '[rows, cols] 'ThinSVD =
    '( '[rows, Min rows cols], '[Min rows cols], '[cols, Min rows cols])
  SVDShapes '[rows, cols] 'FullSVD =
    '( '[rows, rows], '[Min rows cols], '[cols, cols])
  SVDShapes '[batch, 0, cols] 'ThinSVD =
    '( '[batch, 0, 0], '[batch, 0], '[batch, cols, cols])
  SVDShapes '[batch, rows, cols] 'ThinSVD =
    '( '[batch, rows, Min rows cols], '[batch, Min rows cols], '[batch, cols, Min rows cols])
  SVDShapes '[batch, rows, cols] 'FullSVD =
    '( '[batch, rows, rows], '[batch, Min rows cols], '[batch, cols, cols])
  SVDShapes _ _ =
    TypeError (Text "A singular value decomposition can only be computed for 2D matrices for at most one batch dimension.")
-- | Type-level switch between a thin (reduced) and a full SVD in 'svd'.
data ReducedSVD
  = -- | Reduced decomposition: only @min m n@ singular vectors.
    ThinSVD
  | -- | Full decomposition: square U and V factors.
    FullSVD

-- | Reflects a promoted 'ReducedSVD' to the 'Bool' passed to libtorch.
-- 'ThinSVD' maps to 'True' and 'FullSVD' to 'False'.
class KnownReducedSVD (reduced :: ReducedSVD) where
  reducedSVD :: Bool

instance KnownReducedSVD 'ThinSVD where
  reducedSVD = True

instance KnownReducedSVD 'FullSVD where
  reducedSVD = False
-- | DType constraint for 'svd'.
--
-- On @'( 'D.CPU, 0)@ and on any CUDA device the dtype must satisfy both
-- 'DTypeIsFloatingPoint' and 'DTypeIsNotHalf'. Every other device reduces to
-- 'UnsupportedDTypeForDevice'.
type family SVDDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SVDDTypeIsValid '( 'D.CPU, 0) dt =
    (DTypeIsFloatingPoint '( 'D.CPU, 0) dt, DTypeIsNotHalf '( 'D.CPU, 0) dt)
  SVDDTypeIsValid '( 'D.CUDA, idx) dt =
    (DTypeIsFloatingPoint '( 'D.CUDA, idx) dt, DTypeIsNotHalf '( 'D.CUDA, idx) dt)
  SVDDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt
-- | Singular value decomposition of a matrix, or of a batch of matrices.
--
-- The output shapes are computed by 'SVDShapes' from the input shape and the
-- @reduced@ type argument. That argument is reflected through 'reducedSVD' as
-- the first flag. The second flag is always True, presumably so that U and V
-- are computed.
svd ::
  forall reduced shape shapeU shapeS shapeV dtype device.
  ( KnownReducedSVD reduced,
    '(shapeU, shapeS, shapeV) ~ SVDShapes shape reduced,
    SVDDTypeIsValid device dtype
  ) =>
  Tensor device dtype shape ->
  ( Tensor device dtype shapeU,
    Tensor device dtype shapeS,
    Tensor device dtype shapeV
  )
svd input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates new results.
  unsafePerformIO $
    ATen.cast3 ATen.Managed.svd_tbb input (reducedSVD @reduced) True
-- | DType constraint for 'cholesky', 'choleskyInverse' and 'choleskySolve'.
--
-- On @'( 'D.CPU, 0)@ and on any CUDA device the dtype must satisfy both
-- 'DTypeIsFloatingPoint' and 'DTypeIsNotHalf'. Every other device reduces to
-- 'UnsupportedDTypeForDevice'.
type family CholeskyDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  CholeskyDTypeIsValid '( 'D.CPU, 0) dt =
    (DTypeIsFloatingPoint '( 'D.CPU, 0) dt, DTypeIsNotHalf '( 'D.CPU, 0) dt)
  CholeskyDTypeIsValid '( 'D.CUDA, idx) dt =
    (DTypeIsFloatingPoint '( 'D.CUDA, idx) dt, DTypeIsNotHalf '( 'D.CUDA, idx) dt)
  CholeskyDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt
-- | Cholesky factorization of a square matrix, or of a batch of them.
--
-- The input shape must be accepted by 'Square'; the result has the same
-- shape. The 'Tri' argument is converted with 'isUpper' and selects whether
-- the upper or the lower factor is produced.
cholesky ::
  forall shape shape' dtype device.
  ( shape' ~ Square shape,
    CholeskyDTypeIsValid device dtype
  ) =>
  Tri ->
  Tensor device dtype shape ->
  Tensor device dtype shape'
cholesky upper input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $ ATen.cast2 ATen.Managed.cholesky_tb input boolUpper
  where
    boolUpper :: Bool
    boolUpper = isUpper upper
-- | Inverse of a matrix given its Cholesky factor, a non-empty @[n, n]@
-- matrix.
--
-- The 'Tri' argument is converted with 'isUpper' and tells libtorch whether
-- the factor is upper or lower triangular.
choleskyInverse ::
  forall n dtype device.
  ( 1 <= n,
    CholeskyDTypeIsValid device dtype
  ) =>
  Tri ->
  Tensor device dtype '[n, n] ->
  Tensor device dtype '[n, n]
choleskyInverse upper input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $
    ATen.cast2 ATen.Managed.cholesky_inverse_tb input boolUpper
  where
    boolUpper :: Bool
    boolUpper = isUpper upper
-- | Solves a linear system given the Cholesky factor @u@ of its coefficient
-- matrix and a right-hand side @b@.
--
-- @u@ must be square (or a batch of square matrices) with a positive
-- 'FstSquareDim', and @b@ must have the same 'FstSquareDim'. The result has
-- @b@'s shape. The 'Tri' argument is converted with 'isUpper' and tells
-- libtorch which triangle @u@ holds.
choleskySolve ::
  forall m_k m_m dtype device.
  ( Square m_m ~ m_m,
    FstSquareDim m_m ~ FstSquareDim m_k,
    1 <= FstSquareDim m_m,
    CholeskyDTypeIsValid device dtype
  ) =>
  Tri ->
  Tensor device dtype m_k ->
  Tensor device dtype m_m ->
  Tensor device dtype m_k
choleskySolve upper b u =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $
    ATen.cast3 ATen.Managed.cholesky_solve_ttb b u boolUpper
  where
    boolUpper :: Bool
    boolUpper = isUpper upper
-- | DType constraint for 'solve'.
--
-- On @'( 'D.CPU, 0)@ and on any CUDA device the dtype must satisfy both
-- 'DTypeIsFloatingPoint' and 'DTypeIsNotHalf'. Every other device reduces to
-- 'UnsupportedDTypeForDevice'.
type family SolveDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  SolveDTypeIsValid '( 'D.CPU, 0) dt =
    (DTypeIsFloatingPoint '( 'D.CPU, 0) dt, DTypeIsNotHalf '( 'D.CPU, 0) dt)
  SolveDTypeIsValid '( 'D.CUDA, idx) dt =
    (DTypeIsFloatingPoint '( 'D.CUDA, idx) dt, DTypeIsNotHalf '( 'D.CUDA, idx) dt)
  SolveDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt
-- | Solves the linear system(s) with square coefficient matrix @a@ and
-- right-hand side @b@.
--
-- Returns a pair shaped like @(b, a)@; presumably the solution and the LU
-- factorization of @a@, which should be verified against libtorch.
-- 'FstSquareDim' of @a@ must be positive and equal to that of @b@.
solve ::
  forall m_k m_m dtype device.
  ( Square m_m ~ m_m,
    FstSquareDim m_m ~ FstSquareDim m_k,
    1 <= FstSquareDim m_m,
    SolveDTypeIsValid device dtype
  ) =>
  Tensor device dtype m_k ->
  Tensor device dtype m_m ->
  ( Tensor device dtype m_k,
    Tensor device dtype m_m
  )
solve b a =
  -- NOTE(review): treated as pure; assumes the ATen op allocates new results.
  unsafePerformIO $ ATen.cast2 ATen.Managed.solve_tt b a
-- | Low-level QR factorization of an @[m, n]@ matrix (ATen @geqrf@).
--
-- Returns an @[m, n]@ tensor and a vector of length @Min m n@. These are
-- presumably the packed Householder reflectors and their scalar factors
-- (@tau@), which 'orgqr' consumes.
geqrf ::
  forall m n dtype device.
  Tensor device dtype '[m, n] ->
  ( Tensor device dtype '[m, n],
    Tensor device dtype '[Min m n]
  )
geqrf input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates new results.
  unsafePerformIO $ ATen.cast1 ATen.Managed.geqrf_t input
-- | Forms an explicit @[m, n]@ matrix from @a@ and @tau@ (ATen @orgqr@).
--
-- The inputs are presumably the Householder representation produced by
-- 'geqrf'. The constraint @n <= m@ makes @tau@'s length @n@ equal to
-- @Min m n@.
orgqr ::
  forall m n dtype device.
  ( KnownNat n,
    KnownNat m,
    n <= m
  ) =>
  Tensor device dtype '[m, n] ->
  Tensor device dtype '[n] ->
  Tensor device dtype '[m, n]
orgqr a tau =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $ ATen.cast2 ATen.Managed.orgqr_tt a tau
-- | Element-wise sign (ATen @sign@).
--
-- The shape, dtype and device are preserved. There are no dtype constraints
-- at the type level.
sign ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
sign input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $ ATen.cast1 ATen.Managed.sign_t input
-- | Replaces the element at index @i@ of a type-level list with @j@.
--
-- Indices are zero-based. An index past the end of the list leaves the list
-- unchanged.
type family SetValue (shape :: [Nat]) (i :: Nat) (j :: Nat) :: [Nat] where
  SetValue '[] _ _ = '[]
  SetValue (_ ': rest) 0 new = new ': rest
  SetValue (d ': rest) idx new = d ': SetValue rest (idx - 1) new
-- | Element at zero-based index @i@ of a type-level list.
--
-- An index past the end of the list is a custom type error.
type family GetValue (shape :: [Nat]) (i :: Nat) :: Nat where
  GetValue '[] _ =
    TypeError (Text "Cannot find an element in the list: index out of bounds.")
  GetValue (x ': _) 0 = x
  GetValue (_ ': xs) i = GetValue xs (i - 1)
-- | Shape obtained by swapping dimensions @dim0@ and @dim1@ of @shape@.
--
-- Both original sizes are read with 'GetValue' before either is overwritten
-- with 'SetValue'.
type family Transpose (shape :: [Nat]) (dim0 :: Nat) (dim1 :: Nat) :: [Nat] where
  Transpose dims a b =
    SetValue (SetValue dims a (GetValue dims b)) b (GetValue dims a)
-- | Swaps dimensions @n@ and @m@ of a tensor.
--
-- The result shape is computed at compile time by 'Transpose'. Both
-- dimension indices are passed to libtorch via 'natValI'.
transpose ::
  forall n m shape shape' dtype device.
  ( KnownNat n,
    KnownNat m,
    shape' ~ Transpose shape n m
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
transpose input =
  -- NOTE(review): treated as pure; assumes the ATen op returns a new tensor.
  unsafePerformIO $
    ATen.cast3 ATen.Managed.transpose_tll input (natValI @n) (natValI @m)
-- | Matrix transpose: swaps dimensions 0 and 1 of an @[i, j]@ tensor.
transpose2D ::
  forall (i :: Nat) (j :: Nat) dtype device.
  Tensor device dtype '[i, j] ->
  Tensor device dtype '[j, i]
transpose2D = transpose @0 @1
-- | Reflects a promoted 'Tri' (upper or lower triangle) back to its value.
class KnownTri (tri :: Tri) where
  triVal :: Tri

instance KnownTri 'Upper where
  triVal = Upper

instance KnownTri 'Lower where
  triVal = Lower
-- | Length of the diagonal at offset @index@ of an @[m, n]@ matrix.
--
-- * For 'Upper', the offset may be at most @n@ and the length is
--   @Min m (n - index)@.
-- * For 'Lower', the offset may be at most @m@ and the length is
--   @Min (m - index) n@.
--
-- A larger offset is a custom type error.
type family DiagSize (tri :: Tri) (index :: Nat) (m :: Nat) (n :: Nat) :: Nat where
  DiagSize 'Upper k rows cols =
    If
      (k <=? cols)
      (Min rows (cols - k))
      ( TypeError
          ( Text "For a matrix with shape "
              :<>: ShowType '[rows, cols]
              :<>: Text ", the maximum index for an upper diagonal is "
              :<>: ShowType cols
              :<>: Text ", but asked for index "
              :<>: ShowType k
          )
      )
  DiagSize 'Lower k rows cols =
    If
      (k <=? rows)
      (Min (rows - k) cols)
      ( TypeError
          ( Text "For a matrix with shape "
              :<>: ShowType '[rows, cols]
              :<>: Text ", the maximum index for a lower diagonal is "
              :<>: ShowType rows
              :<>: Text ", but asked for index "
              :<>: ShowType k
          )
      )
-- | Result shape of 'diag'.
--
-- * A vector @[n]@ becomes an @[n + index, n + index]@ matrix.
-- * A matrix @[m, n]@ becomes a vector whose length is given by 'DiagSize'.
-- * Any other rank is a custom type error that reports the rank.
type family DiagShape (tri :: Tri) (index :: Nat) (shape :: [Nat]) :: [Nat] where
  DiagShape _ k '[len] = '[len + k, len + k]
  DiagShape side k '[rows, cols] = '[DiagSize side k rows cols]
  DiagShape _ _ other =
    TypeError
      ( Text "The input must be a matrix or a vector, but it has "
          :<>: ShowType (ListLength other)
          :<>: Text " dimensions."
      )
-- | Extracts a diagonal from a matrix, or builds a diagonal matrix from a
-- vector (ATen @diag@).
--
-- The diagonal offset passed to libtorch is @index@ for @'Upper@ and
-- @-index@ for @'Lower@. The result shape is computed by 'DiagShape'.
diag ::
  forall tri index shape shape' device dtype.
  ( KnownTri tri,
    KnownNat index,
    StandardDTypeValidation device dtype,
    shape' ~ DiagShape tri index shape
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
diag t =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $ ATen.cast2 ATen.Managed.tensor_diag_l t offset
  where
    -- Signed diagonal offset: positive above the main diagonal, negative
    -- below it.
    offset :: Int
    offset = case triVal @tri of
      Upper -> natValI @index
      Lower -> negate (natValI @index)
-- | Reduces a boolean tensor over all elements (ATen @all@) to a scalar
-- boolean tensor; presumably True iff every element is True.
all ::
  forall shape device.
  Tensor device 'D.Bool shape ->
  Tensor device 'D.Bool '[]
all input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $ ATen.cast1 ATen.Managed.all_t input
-- | Reduces a boolean tensor over all elements (ATen @any@) to a scalar
-- boolean tensor; presumably True iff at least one element is True.
any ::
  forall shape device.
  Tensor device 'D.Bool shape ->
  Tensor device 'D.Bool '[]
any input =
  -- NOTE(review): treated as pure; assumes the ATen op allocates a new result.
  unsafePerformIO $ ATen.cast1 ATen.Managed.any_t input
-- | Type-level switch for dimension-wise reductions.
--
-- Selects whether the reduced dimension is kept with size 1 or removed (see
-- 'ConditionalDropDimension').
data KeepOrDropDim
  = -- | Keep the reduced dimension, with size 1.
    KeepDim
  | -- | Remove the reduced dimension.
    DropDim

-- | Reflects a promoted 'KeepOrDropDim' to the @keepdim@ 'Bool' passed to
-- libtorch: 'KeepDim' maps to 'True' and 'DropDim' to 'False'.
--
-- The class parameter is kind-annotated so that only 'KeepOrDropDim' values
-- are accepted, consistent with 'KnownReducedSVD' and 'KnownTri'.
class KnownKeepOrDropDim (keepOrDropDim :: KeepOrDropDim) where
  keepOrDropDimVal :: Bool

instance KnownKeepOrDropDim 'KeepDim where
  keepOrDropDimVal = True

instance KnownKeepOrDropDim 'DropDim where
  keepOrDropDimVal = False
-- | Result shape of reducing @shape@ along dimension @dim@.
--
-- With 'KeepDim' the reduced dimension becomes @1@; with 'DropDim' it is
-- removed. Asking for a dimension that @shape@ does not have is a type error.
type family ConditionalDropDimension (shape :: [Nat]) (dim :: Nat) (keepOrDropDim :: KeepOrDropDim) :: [Nat] where
  ConditionalDropDimension '[] _ _ =
    TypeError ('Text "The specified dimension is not available.")
  ConditionalDropDimension (_ ': rest) 0 'KeepDim = 1 ': rest
  ConditionalDropDimension (_ ': rest) 0 'DropDim = rest
  -- walk past leading dimensions until the target index reaches 0
  ConditionalDropDimension (d ': rest) i mode =
    d ': ConditionalDropDimension rest (i - 1) mode
-- | ATen's @all@ reduction along a single dimension @dim@ (cf. @torch.all@
-- with @dim@ / @keepdim@).
--
-- The type-level @keepOrDropDim@ chooses whether the reduced dimension is
-- kept as size @1@ or dropped; the output shape is computed by
-- 'ConditionalDropDimension', so an out-of-range @dim@ is a type error.
--
-- Usage: @allDim \@1 \@'DropDim t@.
allDim ::
  forall dim keepOrDropDim shape' shape device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device 'D.Bool shape ->
  -- | reduced output
  Tensor device 'D.Bool shape'
allDim input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.all_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | ATen's @any@ reduction along a single dimension @dim@ (cf. @torch.any@
-- with @dim@ / @keepdim@).
--
-- The type-level @keepOrDropDim@ chooses whether the reduced dimension is
-- kept as size @1@ or dropped; the output shape is computed by
-- 'ConditionalDropDimension', so an out-of-range @dim@ is a type error.
--
-- Usage: @anyDim \@0 \@'KeepDim t@.
anyDim ::
  forall dim keepOrDropDim shape' shape device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device 'D.Bool shape ->
  -- | reduced output
  Tensor device 'D.Bool shape'
anyDim input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.any_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | Dropout via ATen's @dropout@ (cf. @torch.nn.functional.dropout@).
--
-- The result stays in 'IO' because dropout is stochastic while training:
-- each run of the action may draw a different mask.
dropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether dropout is active (training mode)
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  IO (Tensor device dtype shape)
dropout p train input = ATen.cast3 ATen.Managed.dropout_tdb input p train
-- | Feature dropout via ATen's @feature_dropout@.
--
-- NOTE(review): like 'dropout' this is presumably stochastic when @train@ is
-- 'True', but it is exposed as a pure function through 'unsafePerformIO'.
-- GHC may share or float out identical calls (e.g. two occurrences of
-- @featureDropout p True x@ can be evaluated once), so independent masks are
-- not guaranteed. Returning 'IO' (as 'dropout' does) would be the principled
-- fix, but it changes the public type.
featureDropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether dropout is active (training mode)
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
featureDropout p train input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.feature_dropout_tdb input p train
-- | Alpha dropout via ATen's @alpha_dropout@
-- (cf. @torch.nn.functional.alpha_dropout@).
--
-- NOTE(review): like 'dropout' this is presumably stochastic when @train@ is
-- 'True', but it is exposed as a pure function through 'unsafePerformIO'.
-- GHC may share or float out identical calls, so independent masks are not
-- guaranteed. Returning 'IO' (as 'dropout' does) would be the principled fix,
-- but it changes the public type.
alphaDropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether dropout is active (training mode)
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
alphaDropout p train input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.alpha_dropout_tdb input p train
-- | Feature alpha dropout via ATen's @feature_alpha_dropout@.
--
-- NOTE(review): like 'dropout' this is presumably stochastic when @train@ is
-- 'True', but it is exposed as a pure function through 'unsafePerformIO'.
-- GHC may share or float out identical calls, so independent masks are not
-- guaranteed. Returning 'IO' (as 'dropout' does) would be the principled fix,
-- but it changes the public type.
featureAlphaDropout ::
  forall shape dtype device.
  -- | dropout probability
  Double ->
  -- | whether dropout is active (training mode)
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
featureAlphaDropout p train input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.feature_alpha_dropout_tdb input p train
-- | Elementwise inverse cosine via ATen's @acos@. The admissible dtypes are
-- restricted by 'StandardFloatingPointDTypeValidation'.
acos ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
acos input = unsafePerformIO $ ATen.cast1 ATen.Managed.acos_t input
-- | 1-D average pooling over the last dimension via ATen's @avg_pool1d@.
--
-- Kernel size, stride and padding are type-level naturals. 'ConvSideCheck'
-- validates them and fixes @outputSize@ to
-- @'ConvOutputSize' inputSize kernelSize stride padding@.
avgPool1d ::
  forall
    kernelSize
    stride
    padding
    channelSize
    inputSize
    batchSize
    outputSize
    dtype
    device.
  ( All KnownNat '[kernelSize, stride, padding, channelSize, inputSize, batchSize],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize]
avgPool1d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.avg_pool1d_tlllbb
      input
      (natValI @kernelSize)
      (natValI @stride)
      (natValI @padding)
      -- ceil_mode: must stay False so the runtime length is floored, matching
      -- the 'Div' in 'ConvOutputSize'
      False
      -- count_include_pad
      True
-- | 1-D adaptive average pooling via ATen's @adaptive_avg_pool1d@. The last
-- dimension is pooled to exactly @outputSize@ elements, which are selected
-- with a type application (e.g. @adaptiveAvgPool1d \@4 t@).
adaptiveAvgPool1d ::
  forall outputSize channelSize inputSize batchSize dtype device.
  (All KnownNat '[channelSize, inputSize, batchSize, outputSize]) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize]
adaptiveAvgPool1d input =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.adaptive_avg_pool1d_tl input (natValI @outputSize)
-- | 1-D adaptive max pooling via ATen's @adaptive_max_pool1d@. The last
-- dimension is pooled to exactly @outputSize@ elements.
--
-- Returns the pooled values together with an 'D.Int64' tensor of the same
-- shape holding the indices of the selected maxima.
adaptiveMaxPool1d ::
  forall outputSize channelSize inputSize batchSize dtype device.
  (All KnownNat '[channelSize, inputSize, batchSize, outputSize]) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | (pooled values, indices)
  ( Tensor device dtype '[batchSize, channelSize, outputSize],
    Tensor device 'D.Int64 '[batchSize, channelSize, outputSize]
  )
adaptiveMaxPool1d input =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.adaptive_max_pool1d_tl input (natValI @outputSize)
-- | Scaled matrix-vector product added to a scaled input, via ATen's @addmv@
-- (cf. @torch.addmv@): @beta * input + alpha * (mat \@ vec)@.
--
-- @input@ only has to broadcast against @'[n]@; the result shape is
-- @Broadcast shape '[n]@.
addmv ::
  forall shape' shape n m dtype device.
  ( KnownNat n,
    KnownNat m,
    shape' ~ Broadcast shape '[n]
  ) =>
  -- | beta (multiplier for @input@)
  Float ->
  -- | alpha (multiplier for @mat \@ vec@)
  Float ->
  -- | matrix
  Tensor device dtype '[n, m] ->
  -- | vector
  Tensor device dtype '[m] ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
addmv beta alpha mat vec input =
  unsafePerformIO $
    ATen.cast5 ATen.Managed.addmv_tttss input mat vec beta alpha
-- | Whether two same-shaped tensors are elementwise close, via ATen's
-- @allclose@ (cf. @torch.allclose@).
allclose ::
  forall shape dtype device.
  -- | relative tolerance
  Double ->
  -- | absolute tolerance
  Double ->
  -- | whether NaNs compare equal
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | other
  Tensor device dtype shape ->
  -- | result
  Bool
allclose rtol atol equalNaN input other =
  unsafePerformIO $
    ATen.cast5 ATen.Managed.allclose_ttddb input other rtol atol equalNaN
-- | Indices of the maxima along @dim@, via ATen's @argmax@
-- (cf. @torch.argmax@).
--
-- The type-level @keepOrDropDim@ chooses whether the reduced dimension is
-- kept as size @1@ or dropped (see 'ConditionalDropDimension'). The result
-- is always an 'D.Int64' tensor.
argmax ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | indices
  Tensor device 'D.Int64 shape'
argmax input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.argmax_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | Indices of the minima along @dim@, via ATen's @argmin@
-- (cf. @torch.argmin@).
--
-- The type-level @keepOrDropDim@ chooses whether the reduced dimension is
-- kept as size @1@ or dropped (see 'ConditionalDropDimension'). The result
-- is always an 'D.Int64' tensor.
argmin ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim,
    StandardDTypeValidation device dtype
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | indices
  Tensor device 'D.Int64 shape'
argmin input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.argmin_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | Elementwise inverse sine via ATen's @asin@. The admissible dtypes are
-- restricted by 'StandardFloatingPointDTypeValidation'.
asin ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
asin input = unsafePerformIO $ ATen.cast1 ATen.Managed.asin_t input
-- | Elementwise inverse tangent via ATen's @atan@. The admissible dtypes are
-- restricted by 'StandardFloatingPointDTypeValidation'.
atan ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
atan input = unsafePerformIO $ ATen.cast1 ATen.Managed.atan_t input
-- | Batched matrix product added to a scaled input, via ATen's @baddbmm@
-- (cf. @torch.baddbmm@): @beta * input + alpha * (batch1 \@ batch2)@.
--
-- @input@ only has to broadcast against @'[batchSize, n, m]@.
baddbmm ::
  forall shape' shape batchSize n m k dtype device.
  ( KnownNat n,
    KnownNat m,
    KnownNat k,
    shape' ~ Broadcast shape '[batchSize, n, m]
  ) =>
  -- | beta (multiplier for @input@)
  Float ->
  -- | alpha (multiplier for the batched product)
  Float ->
  -- | first batch of matrices
  Tensor device dtype '[batchSize, n, k] ->
  -- | second batch of matrices
  Tensor device dtype '[batchSize, k, m] ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
baddbmm beta alpha batch1 batch2 input =
  unsafePerformIO $
    ATen.cast5 ATen.Managed.baddbmm_tttss input batch1 batch2 beta alpha
-- | Batched matrix product via ATen's @bmm@: multiplies the @i@-th
-- @[n, k]@ matrix of @input@ by the @i@-th @[k, m]@ matrix of @other@.
bmm ::
  forall batchSize n m k dtype device.
  -- | input
  Tensor device dtype '[batchSize, n, k] ->
  -- | other
  Tensor device dtype '[batchSize, k, m] ->
  -- | output
  Tensor device dtype '[batchSize, n, m]
bmm input other = unsafePerformIO $ ATen.cast2 ATen.Managed.bmm_tt input other
-- | Computes the common broadcast shape, dtype and device of a type-level
-- list of tensors.
--
-- The accumulator is 'Nothing' before the first tensor is seen. Once a
-- tensor has seeded it, the fold continues in 'BroadcastTensorsGo', where
-- 'Nothing' means failure and stays failure. Previously a failure could be
-- mistaken for the initial state, so a later tensor restarted the fold and
-- an incompatible list typechecked with a wrong result type.
type family BroadcastTensorsImpl (tensors :: [a]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  -- nothing to broadcast
  BroadcastTensorsImpl '[] 'Nothing = 'Nothing
  -- the first tensor seeds the accumulator (shape kept reversed)
  BroadcastTensorsImpl (Tensor device dtype shape ': rest) 'Nothing =
    BroadcastTensorsGo rest ('Just '(Reverse shape, dtype, device))
  BroadcastTensorsImpl rest ('Just acc) = BroadcastTensorsGo rest ('Just acc)

-- | Fold step for 'BroadcastTensorsImpl'. A 'Nothing' accumulator means an
-- earlier tensor was incompatible, and the result then stays 'Nothing'.
--
-- The shape is accumulated reversed; presumably this is because
-- 'ComputeBroadcast' aligns shapes from their trailing dimensions.
type family BroadcastTensorsGo (tensors :: [a]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  -- failure is sticky
  BroadcastTensorsGo _ 'Nothing = 'Nothing
  -- done: undo the reversal of the accumulated shape
  BroadcastTensorsGo '[] ('Just '(revShape, dtype, device)) =
    'Just '(Reverse revShape, dtype, device)
  -- dtype and device must repeat (non-linear match); shapes must broadcast
  BroadcastTensorsGo (Tensor device dtype shape ': rest) ('Just '(revShape, dtype, device)) =
    BroadcastTensorsGo rest (MaybeTriple (ComputeBroadcast (Reverse shape) revShape) ('Just dtype) ('Just device))
  -- dtype or device mismatch
  BroadcastTensorsGo (Tensor _ _ _ ': _) ('Just _) = 'Nothing

-- | Turns the result of 'BroadcastTensorsImpl' into the output list: one
-- broadcast tensor per input, or a type error when broadcasting failed.
type family BroadcastTensorsCheck (tensors :: [a]) (result :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: [a] where
  BroadcastTensorsCheck tensors 'Nothing =
    TypeError
      ( 'Text "Cannot broadcast tensors due to incompatible shapes and/or dtypes: "
          ':<>: 'ShowType tensors
      )
  BroadcastTensorsCheck tensors ('Just '(shape, dtype, device)) =
    HReplicateR (ListLength tensors) (Tensor device dtype shape)

-- | Output types of 'broadcastTensors' for the input list @tensors@.
type BroadcastTensors tensors =
  BroadcastTensorsCheck tensors (BroadcastTensorsImpl tensors 'Nothing)
-- | Broadcasts a heterogeneous list of tensors to a common shape via ATen's
-- @broadcast_tensors@ (cf. @torch.broadcast_tensors@).
--
-- The output types are computed by 'BroadcastTensors'. Every output has the
-- common broadcast shape and the shared dtype and device, and incompatible
-- inputs are rejected at compile time.
broadcastTensors ::
  forall tensors tensors'.
  ( tensors' ~ BroadcastTensors tensors,
    ATen.Castable (HList tensors) [D.ATenTensor],
    ATen.Castable (HList tensors') [D.ATenTensor]
  ) =>
  -- | input tensors
  HList tensors ->
  -- | broadcast tensors
  HList tensors'
broadcastTensors tensors =
  unsafePerformIO $ ATen.cast1 ATen.Managed.broadcast_tensors_l tensors
-- | Computes the shape, dtype and device of concatenating a type-level list
-- of tensors along @dim@.
--
-- The accumulator is 'Nothing' before the first tensor is seen. After the
-- first step the fold continues in 'CatImplGo', where 'Nothing' means
-- failure and stays failure. Previously a failure could be mistaken for the
-- initial state, so a later tensor restarted the fold and an incompatible
-- list was accepted with a wrong result type.
type family CatImpl (dim :: Nat) (tensors :: [a]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  CatImpl _ '[] acc = acc
  CatImpl dim (Tensor device dtype shape ': rest) acc =
    CatImplGo
      dim
      rest
      ( MaybeTriple
          (ComputeCatShape dim shape acc)
          (ComputeCatDType dtype acc)
          (ComputeCatDevice device acc)
      )

-- | Fold step for 'CatImpl'. A 'Nothing' accumulator means an earlier tensor
-- was incompatible, and the result then stays 'Nothing'.
type family CatImplGo (dim :: Nat) (tensors :: [a]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  -- failure is sticky
  CatImplGo _ _ 'Nothing = 'Nothing
  CatImplGo _ '[] ('Just result) = 'Just result
  CatImplGo dim (Tensor device dtype shape ': rest) ('Just acc) =
    CatImplGo
      dim
      rest
      ( MaybeTriple
          (ComputeCatShape dim shape ('Just acc))
          (ComputeCatDType dtype ('Just acc))
          (ComputeCatDevice device ('Just acc))
      )

-- | Shape after appending a tensor of shape @shape@ to the accumulator along
-- @dim@. The sizes at @dim@ are added, and every other dimension must match.
type family ComputeCatShape (dim :: Nat) (shape :: [Nat]) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe [Nat] where
  -- first tensor: its shape is taken as-is, while walking to @dim@ (an
  -- out-of-range @dim@ falls through to 'Nothing')
  ComputeCatShape 0 (d ': ds) 'Nothing = 'Just (d ': ds)
  ComputeCatShape dim (d ': ds) 'Nothing =
    AppendToMaybe d (ComputeCatShape (dim - 1) ds 'Nothing)
  -- at the concatenation dimension: sizes add, trailing dims must be equal
  ComputeCatShape 0 (d ': ds) ('Just '(e ': ds, _, _)) = 'Just ((d + e) ': ds)
  -- before the concatenation dimension: leading sizes must be equal
  ComputeCatShape dim (d ': ds) ('Just '(d ': es, dtype, device)) =
    AppendToMaybe d (ComputeCatShape (dim - 1) ds ('Just '(es, dtype, device)))
  ComputeCatShape _ _ _ = 'Nothing

-- | Common dtype of the tensors seen so far, or 'Nothing' on a mismatch.
type family ComputeCatDType (dtype :: D.DType) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe D.DType where
  ComputeCatDType dtype 'Nothing = 'Just dtype
  ComputeCatDType dtype ('Just '(_, dtype, _)) = 'Just dtype
  ComputeCatDType _ _ = 'Nothing

-- | Common device of the tensors seen so far, or 'Nothing' on a mismatch.
type family ComputeCatDevice (device :: (D.DeviceType, Nat)) (acc :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: Maybe (D.DeviceType, Nat) where
  ComputeCatDevice device 'Nothing = 'Just device
  ComputeCatDevice device ('Just '(_, _, device)) = 'Just device
  ComputeCatDevice _ _ = 'Nothing

-- | Unwraps the result of 'CatImpl', or reports a type error.
type family CatCheck (res :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: ([Nat], D.DType, (D.DeviceType, Nat)) where
  CatCheck 'Nothing = TypeError ('Text "Concatenation impossible.")
  CatCheck ('Just '(shape, dtype, device)) = '(shape, dtype, device)

-- | Shape, dtype and device produced by 'cat' along @dim@.
type Cat dim tensors = CatCheck (CatImpl dim tensors 'Nothing)
-- | Concatenates a heterogeneous list of tensors along @dim@ via ATen's
-- @cat@.
--
-- The result shape, dtype and device are computed by 'Cat'. The sizes at
-- @dim@ are summed, and all other dimensions, dtypes and devices must agree;
-- otherwise compilation fails.
cat ::
  forall dim shape dtype device tensors.
  ( KnownNat dim,
    '(shape, dtype, device) ~ Cat dim tensors,
    ATen.Castable (HList tensors) [D.ATenTensor]
  ) =>
  -- | tensors to concatenate
  HList tensors ->
  -- | concatenation
  Tensor device dtype shape
cat tensors =
  unsafePerformIO $
    ATen.cast2 ATen.Managed.cat_ll tensors (natValI @dim :: Int)
-- | Builds the list of chunk tensor types from the list of chunk shapes. All
-- chunks share @dtype@ and @device@, and 'Nothing' propagates.
type family ChunkImpl (chunkShapes :: Maybe [[Nat]]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) :: Maybe a where
  ChunkImpl ('Just '[]) _ _ = 'Just '[]
  ChunkImpl ('Just (s ': ss)) dtype device =
    AppendToMaybe (Tensor device dtype s) (ChunkImpl ('Just ss) dtype device)
  ChunkImpl 'Nothing _ _ = 'Nothing

-- | Unwraps the chunk list. When it is 'Nothing' (e.g. @dim@ could not be
-- extracted from @shape@), this reports 'DimOutOfBound'.
type family ChunkCheck (shape :: [Nat]) (dim :: Nat) (result :: Maybe a) :: a where
  ChunkCheck shape dim 'Nothing = DimOutOfBound shape dim
  ChunkCheck _ _ ('Just chunkList) = chunkList

-- | Chunk sizes for a dimension that does not divide evenly.
--
-- Full chunks of @size@ are emitted while at least @size@ elements remain
-- (third argument is @CmpNat rest size@). Then one final partial chunk of
-- the remaining elements is emitted if any are left (fourth argument is
-- @CmpNat rest 0@). As with @torch.chunk@, this may yield fewer chunks than
-- requested.
type family ComputeChunksChunkGo (n' :: Nat) (r :: Nat) (cmp :: Ordering) (cmp' :: Ordering) :: [Nat] where
  ComputeChunksChunkGo size rest 'GT _ =
    size ': ComputeChunksChunkGo size (rest - size) (CmpNat (rest - size) size) (CmpNat (rest - size) 0)
  ComputeChunksChunkGo size rest 'EQ _ =
    size ': ComputeChunksChunkGo size (rest - size) (CmpNat (rest - size) size) (CmpNat (rest - size) 0)
  ComputeChunksChunkGo _ rest _ 'GT = '[rest]
  ComputeChunksChunkGo _ _ _ _ = '[]

-- | @count@ copies of @size@: the evenly divisible case.
type family ComputeChunksChunkGo0 (n' :: Nat) (chunks :: Nat) :: [Nat] where
  ComputeChunksChunkGo0 _ 0 = '[]
  ComputeChunksChunkGo0 size count = size ': ComputeChunksChunkGo0 size (count - 1)

-- | Chunk sizes for a dimension of size @n@ split into @chunks@ pieces, or
-- 'Nothing' if the dimension size is unknown.
type family ComputeChunks (n :: Maybe Nat) (chunks :: Nat) :: Maybe [Nat] where
  ComputeChunks ('Just n) chunks = 'Just (ComputeChunks' n chunks (Mod n chunks))
  ComputeChunks 'Nothing _ = 'Nothing

-- | Dispatches on @Mod n chunks@. An even split uses @Div n chunks@ per
-- chunk. Otherwise each chunk has size @ceil (n / chunks)@ and the last one
-- is smaller.
type family ComputeChunks' (n :: Nat) (chunks :: Nat) (m :: Nat) :: [Nat] where
  ComputeChunks' n chunks 0 = ComputeChunksChunkGo0 (Div n chunks) chunks
  ComputeChunks' n chunks _ =
    ComputeChunksChunkGo
      (Div (n + chunks - 1) chunks)
      n
      (CmpNat n (Div (n + chunks - 1) chunks))
      (CmpNat n 0)

-- | Replaces dimension @dim@ of @shape@ with each chunk size in turn.
type family ChunkShapesImpl (chunks :: Maybe [Nat]) (dim :: Nat) (shape :: [Nat]) :: Maybe [[Nat]] where
  ChunkShapesImpl ('Just (c ': cs)) dim shape =
    AppendToMaybe' (ReplaceDim dim shape c) (ChunkShapesImpl ('Just cs) dim shape)
  ChunkShapesImpl ('Just '[]) _ _ = 'Just '[]
  ChunkShapesImpl 'Nothing _ _ = 'Nothing

-- | Shapes of the pieces produced by splitting @shape@ along @dim@.
type ChunkShapes chunks dim shape =
  ChunkShapesImpl (ComputeChunks (ExtractDim dim shape) chunks) dim shape

-- | Tensor types produced by 'chunk'.
type Chunk chunks dim shape dtype device =
  ChunkCheck shape dim (ChunkImpl (ChunkShapes chunks dim shape) dtype device)
-- | Splits @input@ into at most @chunks@ pieces along @dim@ via ATen's
-- @chunk@ (cf. @torch.chunk@).
--
-- The piece shapes are computed by 'Chunk'. When the size at @dim@ is not
-- divisible by @chunks@, the last piece is smaller and fewer than @chunks@
-- pieces may be returned.
chunk ::
  forall chunks dim shape dtype device tensorChunks.
  ( KnownNat chunks,
    KnownNat dim,
    tensorChunks ~ Chunk chunks dim shape dtype device,
    ATen.Castable (HList tensorChunks) [D.ATenTensor]
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | pieces
  HList tensorChunks
chunk input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.chunk_tll
      input
      (natValI @chunks :: Int)
      (natValI @dim :: Int)
-- | Clamps every element into the closed interval
-- @[minValue, maxValue]@ via ATen's @clamp@.
--
-- The bounds are named @minValue@ / @maxValue@ rather than @min@ / @max@ so
-- they do not shadow other bindings of those names.
clamp ::
  forall shape dtype device a.
  (D.Scalar a) =>
  -- | lower bound
  a ->
  -- | upper bound
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
clamp minValue maxValue input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.clamp_tss input minValue maxValue
-- | Caps every element at @maxValue@ via ATen's @clamp_max@.
--
-- The bound is named @maxValue@ rather than @max@ so it does not shadow
-- other bindings of that name.
clampMax ::
  forall shape dtype device a.
  (D.Scalar a) =>
  -- | upper bound
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
clampMax maxValue input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.clamp_max_ts input maxValue
-- | Raises every element to at least @minValue@ via ATen's @clamp_min@.
--
-- The bound is named @minValue@ rather than @min@ so it does not shadow
-- other bindings of that name.
clampMin ::
  forall shape dtype device a.
  (D.Scalar a) =>
  -- | lower bound
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
clampMin minValue input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.clamp_min_ts input minValue
-- | ATen's @cudnn_is_acceptable@ predicate for @input@. Presumably this
-- reports whether cuDNN kernels may be used with this tensor; see the ATen
-- documentation for the exact criteria.
cudnnIsAcceptable ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | result
  Bool
cudnnIsAcceptable input =
  unsafePerformIO $ ATen.cast1 ATen.Managed.cudnn_is_acceptable_t input
-- | Pads a 1-D tensor with a constant via ATen's @constant_pad_nd@.
--
-- @Fst pad@ elements are added before the data and @Snd pad@ after it
-- (cf. @torch.nn.functional.pad@ with @(left, right)@), each equal to
-- @value@. The output length is @n + Fst pad + Snd pad@.
constantPadNd1d ::
  forall (pad :: (Nat, Nat)) n dtype device.
  (All KnownNat '[Torch.Typed.Auxiliary.Fst pad, Torch.Typed.Auxiliary.Snd pad, n]) =>
  -- | fill value
  Float ->
  -- | input
  Tensor device dtype '[n] ->
  -- | padded output
  Tensor device dtype '[n + Torch.Typed.Auxiliary.Fst pad + Torch.Typed.Auxiliary.Snd pad]
constantPadNd1d value input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.constant_pad_nd_tls
      input
      ( [ natValI @(Torch.Typed.Auxiliary.Fst pad),
          natValI @(Torch.Typed.Auxiliary.Snd pad)
        ] ::
          [Int]
      )
      value
-- | Constraints relating one spatial side of a convolution or pooling window
-- (without dilation) to its output size.
--
-- The kernel must fit inside the padded input:
-- @kernelSize <= inputSize + 2 * padding@. This is exactly the condition
-- under which the subtraction in 'ConvOutputSize' is defined. The previous
-- bound, @(kernelSize - 1) <= …@, was off by one: it admitted a kernel one
-- element too large, which then surfaced as a stuck 'Nat' subtraction
-- instead of a readable inequality error.
type ConvSideCheck (inputSize :: Nat) (kernelSize :: Nat) (stride :: Nat) (padding :: Nat) (outputSize :: Nat) =
  ( -- kernel size and stride must be positive
    1 <= kernelSize,
    1 <= stride,
    -- the kernel must fit inside the padded input
    kernelSize <= (inputSize + (2 * padding)),
    -- the output must be non-empty
    1 <= outputSize,
    -- output length of the sliding window
    outputSize ~ ConvOutputSize inputSize kernelSize stride padding
  )
type family ConvOutputSize (inputSize :: Nat) (kernelSize :: Nat) (stride :: Nat) (padding :: Nat) :: Nat where
ConvOutputSize inputSize kernelSize stride padding = (Div ((inputSize + (2 * padding)) - kernelSize) stride) + 1
-- | 1-D convolution with a type-checked output length.
--
-- The Haskell arguments are, in order:
--
-- * weight @'[outputChannelSize, inputChannelSize, kernelSize]@;
-- * bias @'[outputChannelSize]@;
-- * input @'[batchSize, inputChannelSize, inputSize]@.
--
-- ATen receives them reordered as input, weight, bias. @stride@ and
-- @padding@ are type-level, and 'ConvSideCheck' ties @outputSize@ to them.
-- The two trailing ATen arguments are fixed at 1; per the ATen @conv1d@
-- schema these are dilation and groups.
conv1d ::
  forall
    (stride :: Nat)
    (padding :: Nat)
    inputChannelSize
    outputChannelSize
    kernelSize
    inputSize
    batchSize
    outputSize
    dtype
    device.
  ( All
      KnownNat
      '[ stride,
         padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize,
         inputSize,
         batchSize,
         outputSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  Tensor device dtype '[outputChannelSize, inputChannelSize, kernelSize] ->
  Tensor device dtype '[outputChannelSize] ->
  Tensor device dtype '[batchSize, inputChannelSize, inputSize] ->
  Tensor device dtype '[batchSize, outputChannelSize, outputSize]
conv1d weight bias input =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv1d_tttllll
      input
      weight
      bias
      (natValI @stride :: Int)
      (natValI @padding :: Int)
      (1 :: Int)
      (1 :: Int)
-- | 2-D convolution with type-checked output sizes.
--
-- The Haskell arguments are, in order:
--
-- * weight @'[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1]@;
-- * bias @'[outputChannelSize]@;
-- * input @'[batchSize, inputChannelSize, inputSize0, inputSize1]@.
--
-- ATen receives them reordered as input, weight, bias. @stride@ and
-- @padding@ are type-level pairs, one component per spatial dimension. Each
-- dimension is checked with 'ConvSideCheck'. The trailing ATen arguments are
-- fixed at @[1, 1]@ and @1@; per the ATen @conv2d@ schema these are dilation
-- and groups.
conv2d ::
  forall
    (stride :: (Nat, Nat))
    (padding :: (Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         inputSize0,
         inputSize1,
         batchSize,
         outputSize0,
         outputSize1
       ],
    ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  Tensor device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1] ->
  Tensor device dtype '[outputChannelSize] ->
  Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1] ->
  Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1]
conv2d weight bias input =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv2d_tttllll
      input
      weight
      bias
      ( [ natValI @(Torch.Typed.Auxiliary.Fst stride),
          natValI @(Torch.Typed.Auxiliary.Snd stride)
        ] ::
          [Int]
      )
      ( [ natValI @(Torch.Typed.Auxiliary.Fst padding),
          natValI @(Torch.Typed.Auxiliary.Snd padding)
        ] ::
          [Int]
      )
      ([1, 1] :: [Int])
      (1 :: Int)
-- | 3-D convolution with type-checked output sizes.
--
-- The Haskell arguments are, in order:
--
-- * weight
--   @'[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1, kernelSize2]@;
-- * bias @'[outputChannelSize]@;
-- * input @'[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2]@.
--
-- ATen receives them reordered as input, weight, bias. @stride@ and
-- @padding@ are type-level triples, one component per spatial dimension.
-- Each dimension is checked with 'ConvSideCheck'. The trailing ATen arguments
-- are fixed at @[1, 1, 1]@ and @1@; per the ATen @conv3d@ schema these are
-- dilation and groups.
conv3d ::
  forall
    (stride :: (Nat, Nat, Nat))
    (padding :: (Nat, Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    kernelSize2
    inputSize0
    inputSize1
    inputSize2
    batchSize
    outputSize0
    outputSize1
    outputSize2
    dtype
    device.
  ( All
      KnownNat
      '[ Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         kernelSize2,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 kernelSize0 (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 kernelSize2 (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  Tensor device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1, kernelSize2] ->
  Tensor device dtype '[outputChannelSize] ->
  Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2] ->
  Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1, outputSize2]
conv3d weight bias input =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv3d_tttllll
      input
      weight
      bias
      ( [ natValI @(Fst3 stride),
          natValI @(Snd3 stride),
          natValI @(Trd3 stride)
        ] ::
          [Int]
      )
      ( [ natValI @(Fst3 padding),
          natValI @(Snd3 padding),
          natValI @(Trd3 padding)
        ] ::
          [Int]
      )
      ([1, 1, 1] :: [Int])
      (1 :: Int)
-- | Convolution over input laid out as @'[timeSize, batchSize, inputChannels]@.
--
-- The Haskell arguments are, in order:
--
-- * weight @'[kernelSize, inputChannels, outputChannels]@;
-- * bias @'[outputChannels]@;
-- * input.
--
-- ATen's @conv_tbc@ receives them reordered as input, weight, bias, followed
-- by the type-level @padding@. The output length along the time dimension is
-- @timeSize + padding * 2 + 1 - kernelSize@.
convTBC ::
  forall padding timeSize batchSize kernelSize inputChannels outputChannels dtype device.
  (KnownNat padding) =>
  Tensor device dtype '[kernelSize, inputChannels, outputChannels] ->
  Tensor device dtype '[outputChannels] ->
  Tensor device dtype '[timeSize, batchSize, inputChannels] ->
  Tensor device dtype '[timeSize + padding * 2 + 1 - kernelSize, batchSize, outputChannels]
convTBC weight bias input =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $
    ATen.cast4 ATen.Managed.conv_tbc_tttl input weight bias (natValI @padding)
-- | 1-D transposed convolution.
--
-- The Haskell arguments are, in order:
--
-- * weight @'[inputChannelSize, outputChannelSize, kernelSize]@;
-- * bias @'[outputChannelSize]@;
-- * input @'[batchSize, inputChannelSize, inputSize]@.
--
-- ATen receives them reordered as input, weight, bias. The trailing ATen
-- arguments are fixed at @0@ and @1@; per the ATen @conv_transpose1d@ schema
-- these are output padding and groups.
--
-- NOTE(review): the shape constraint reuses 'ConvSideCheck', which encodes
-- the forward-convolution size formula. PyTorch documents the transposed
-- output length as @(inputSize - 1) * stride - 2 * padding + kernelSize@
-- (with dilation 1 and output padding 0). The type-level @outputSize@ is
-- therefore likely to disagree with the runtime result except in degenerate
-- cases such as @kernelSize = stride = 1, padding = 0@.
--
-- The constraint is left unchanged because fixing it changes which call sites
-- type-check. TODO confirm and fix.
convTranspose1d ::
  forall
    (stride :: Nat)
    (padding :: Nat)
    inputChannelSize
    outputChannelSize
    kernelSize
    inputSize
    batchSize
    outputSize
    dtype
    device.
  ( All
      KnownNat
      '[ stride,
         padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize,
         inputSize,
         batchSize,
         outputSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  Tensor device dtype '[inputChannelSize, outputChannelSize, kernelSize] ->
  Tensor device dtype '[outputChannelSize] ->
  Tensor device dtype '[batchSize, inputChannelSize, inputSize] ->
  Tensor device dtype '[batchSize, outputChannelSize, outputSize]
convTranspose1d weight bias input =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv_transpose1d_tttllll
      input
      weight
      bias
      (natValI @stride :: Int)
      (natValI @padding :: Int)
      (0 :: Int)
      (1 :: Int)
-- | 2-D transposed convolution.
--
-- The Haskell arguments are, in order:
--
-- * weight @'[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1]@;
-- * bias @'[outputChannelSize]@;
-- * input @'[batchSize, inputChannelSize, inputSize0, inputSize1]@.
--
-- ATen receives them reordered as input, weight, bias. @stride@ and
-- @padding@ are type-level pairs. The trailing ATen arguments are fixed at
-- @[0, 0]@ and @1@; per the ATen @conv_transpose2d@ schema these are output
-- padding and groups.
--
-- NOTE(review): as in 'convTranspose1d', each dimension is constrained with
-- 'ConvSideCheck', which uses the forward-convolution size formula rather
-- than the transposed one. The type-level output sizes likely disagree with
-- the runtime result in general. TODO confirm and fix.
convTranspose2d ::
  forall
    (stride :: (Nat, Nat))
    (padding :: (Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         inputSize0,
         inputSize1,
         batchSize,
         outputSize0,
         outputSize1
       ],
    ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  Tensor device dtype '[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1] ->
  Tensor device dtype '[outputChannelSize] ->
  Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1] ->
  Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1]
convTranspose2d weight bias input =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv_transpose2d_tttllll
      input
      weight
      bias
      ( [ natValI @(Torch.Typed.Auxiliary.Fst stride),
          natValI @(Torch.Typed.Auxiliary.Snd stride)
        ] ::
          [Int]
      )
      ( [ natValI @(Torch.Typed.Auxiliary.Fst padding),
          natValI @(Torch.Typed.Auxiliary.Snd padding)
        ] ::
          [Int]
      )
      ([0, 0] :: [Int])
      (1 :: Int)
-- | 3-D transposed convolution.
--
-- The Haskell arguments are, in order:
--
-- * weight
--   @'[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1, kernelSize2]@;
-- * bias @'[outputChannelSize]@;
-- * input @'[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2]@.
--
-- ATen receives them reordered as input, weight, bias. @stride@ and
-- @padding@ are type-level triples. The trailing ATen arguments are fixed at
-- @[0, 0, 0]@ and @1@; per the ATen @conv_transpose3d@ schema these are
-- output padding and groups.
--
-- NOTE(review): as in 'convTranspose1d', each dimension is constrained with
-- 'ConvSideCheck', which uses the forward-convolution size formula rather
-- than the transposed one. The type-level output sizes likely disagree with
-- the runtime result in general. TODO confirm and fix.
convTranspose3d ::
  forall
    (stride :: (Nat, Nat, Nat))
    (padding :: (Nat, Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    kernelSize2
    inputSize0
    inputSize1
    inputSize2
    batchSize
    outputSize0
    outputSize1
    outputSize2
    dtype
    device.
  ( All
      KnownNat
      '[ Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         kernelSize2,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 kernelSize0 (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 kernelSize2 (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  Tensor device dtype '[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1, kernelSize2] ->
  Tensor device dtype '[outputChannelSize] ->
  Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2] ->
  Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1, outputSize2]
convTranspose3d weight bias input =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv_transpose3d_tttllll
      input
      weight
      bias
      ( [ natValI @(Fst3 stride),
          natValI @(Snd3 stride),
          natValI @(Trd3 stride)
        ] ::
          [Int]
      )
      ( [ natValI @(Fst3 padding),
          natValI @(Snd3 padding),
          natValI @(Trd3 padding)
        ] ::
          [Int]
      )
      ([0, 0, 0] :: [Int])
      (1 :: Int)
-- | Element-wise hyperbolic cosine, computed by ATen's @cosh@.
--
-- The shape is preserved. The dtype must satisfy
-- 'StandardFloatingPointDTypeValidation' for the device.
cosh ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
cosh input =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $ ATen.cast1 ATen.Managed.cosh_t input
-- | Result shape of a determinant.
--
-- * A square matrix @'[n, n]@ yields a scalar @'[]@.
-- * A batch of square matrices @'[b, n, n]@ yields @'[b]@.
-- * Any other shape is rejected with a type error.
type family Det (shape :: [Nat]) :: [Nat] where
  Det (n ': n ': '[]) = '[]
  Det (b ': n ': n ': '[]) = '[b]
  Det _ = TypeError (Text "This shape must be square matrix or batch + square matrix.")
-- | Determinant via ATen's @det@.
--
-- The result shape is given by the 'Det' type family: a square matrix yields
-- a scalar, and a batch of square matrices yields one value per batch entry.
det ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype (Det shape)
det input =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $ ATen.cast1 ATen.Managed.det_t input
-- | Helper for 'DimsDistinctAscending'.
--
-- @cmp@ is the result of comparing the two dimensions. Only 'LT' (strictly
-- ascending) is accepted; anything else is reported as a type error naming
-- both dimensions.
type family DimsDistinctAscendingCheck (dim1 :: Nat) (dim2 :: Nat) (cmp :: Ordering) :: Constraint where
  DimsDistinctAscendingCheck _ _ 'LT = ()
  DimsDistinctAscendingCheck dim1 dim2 _ =
    TypeError
      ( Text "Dimensions must be distinct and in ascending order, but got "
          :<>: ShowType dim1
          :<>: Text ", "
          :<>: ShowType dim2
      )

-- | Constraint that @dim1 < dim2@ (strictly), with a readable error otherwise.
type family DimsDistinctAscending (dim1 :: Nat) (dim2 :: Nat) :: Constraint where
  DimsDistinctAscending dim1 dim2 = DimsDistinctAscendingCheck dim1 dim2 (CmpNat dim1 dim2)
-- | Helper for 'DiagEmbedShape'.
--
-- Drops the last dimension of @shape@, then inserts @n@ at index
-- @dim2 - 1@ and afterwards at index @dim1@.
type family DiagEmbedShapeImpl (dim1 :: Nat) (dim2 :: Nat) (shape :: [Nat]) (n :: Nat) :: [Nat] where
  DiagEmbedShapeImpl dim1 dim2 shape n = Insert dim1 n (Insert (dim2 - 1) n (Init shape))

-- | Result shape of a diagonal embedding with offset @index@.
--
-- The new square dimensions have size @Last shape + index@.
type family DiagEmbedShape (index :: Nat) (dim1 :: Nat) (dim2 :: Nat) (shape :: [Nat]) :: [Nat] where
  DiagEmbedShape index dim1 dim2 shape = DiagEmbedShapeImpl dim1 dim2 shape (Last shape + index)
-- | Embed values as a diagonal of new dimensions @dim1@ and @dim2@, via
-- ATen's @diag_embed@.
--
-- The diagonal offset passed to ATen is @+index@ for 'Upper' and @-index@
-- otherwise. @dim1@ must be strictly less than @dim2@. The result shape is
-- 'DiagEmbedShape'.
diagEmbed ::
  forall index dim1 dim2 shape shape' device dtype.
  ( KnownNat index,
    KnownNat dim1,
    KnownNat dim2,
    shape' ~ DiagEmbedShape index dim1 dim2 shape,
    DimsDistinctAscending dim1 dim2,
    StandardDTypeValidation device dtype
  ) =>
  Tri ->
  Tensor device dtype shape ->
  Tensor device dtype shape'
diagEmbed tri t =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $
    ATen.cast4
      ATen.Managed.diag_embed_tlll
      t
      (if isUpper tri then natValI @index else - natValI @index)
      (natValI @dim1)
      (natValI @dim2)
-- | Helper for 'DiagflatShape': a square @'[d, d]@ shape.
type family DiagflatShapeImpl (d :: Nat) :: [Nat] where
  DiagflatShapeImpl d = '[d, d]

-- | Result shape of 'diagflat' with offset @index@.
--
-- The result is a square matrix whose side is the element count of @shape@
-- plus @index@.
type family DiagflatShape (index :: Nat) (shape :: [Nat]) :: [Nat] where
  DiagflatShape index shape = DiagflatShapeImpl (Numel shape + index)
-- | Place the flattened input on a diagonal of a square matrix, via ATen's
-- @diagflat@.
--
-- The diagonal offset passed to ATen is @+index@ for 'Upper' and @-index@ for
-- 'Lower'. The result shape is 'DiagflatShape'.
diagflat ::
  forall index shape shape' device dtype.
  ( KnownNat index,
    shape' ~ DiagflatShape index shape,
    StandardDTypeValidation device dtype
  ) =>
  Tri ->
  Tensor device dtype shape ->
  Tensor device dtype shape'
diagflat tri t =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $ ATen.cast2 ATen.Managed.diagflat_tl t offset
  where
    offset :: Int
    offset = case tri of
      Upper -> natValI @index
      Lower -> - natValI @index
-- | Helper for 'NDimAtLeast'.
--
-- @cmp@ compares @ndim@ against the rank of @shape@. A type error is raised
-- only when @ndim@ is greater than the rank.
type family NDimAtLeastCheck (ndim :: Nat) (shape :: [Nat]) (cmp :: Ordering) :: Constraint where
  NDimAtLeastCheck ndim shape 'GT =
    TypeError
      ( Text "Input must have at least "
          :<>: ShowType ndim
          :<>: Text " dimensions, but got "
          :<>: ShowType (ListLength shape)
      )
  NDimAtLeastCheck _ _ _ = ()

-- | Constraint that @shape@ has at least @ndim@ dimensions.
type family NDimAtLeast (ndim :: Nat) (shape :: [Nat]) :: Constraint where
  NDimAtLeast ndim shape = NDimAtLeastCheck ndim shape (CmpNat ndim (ListLength shape))
-- | Result shape of 'diagonal'.
--
-- Dimensions @dim2@ and then @dim1@ are removed from @shape@. A trailing
-- dimension is then appended, whose length is computed by 'DiagSize' from
-- @tri@, @index@ and the sizes of the two removed dimensions.
type family DiagonalShape (tri :: Tri) (index :: Nat) (dim1 :: Nat) (dim2 :: Nat) (shape :: [Nat]) :: [Nat] where
  DiagonalShape tri index dim1 dim2 shape =
    Remove (Remove shape dim2) dim1 ++ '[DiagSize tri index (Index shape dim1) (Index shape dim2)]
-- | Extract a diagonal across dimensions @dim1@ and @dim2@, via ATen's
-- @diagonal@.
--
-- The offset passed to ATen is @+index@ when the type-level @tri@ is upper
-- and @-index@ otherwise. The input needs at least two dimensions, and
-- @dim1@ must be strictly less than @dim2@. The result shape is
-- 'DiagonalShape'.
diagonal ::
  forall tri index dim1 dim2 shape shape' device dtype.
  ( KnownTri tri,
    KnownNat index,
    KnownNat dim1,
    KnownNat dim2,
    NDimAtLeast 2 shape,
    DimsDistinctAscending dim1 dim2,
    shape' ~ DiagonalShape tri index dim1 dim2 shape,
    StandardDTypeValidation device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
diagonal t =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $
    ATen.cast4
      ATen.Managed.diagonal_tlll
      t
      (if isUpper (triVal @tri) then natValI @index else - natValI @index)
      (natValI @dim1)
      (natValI @dim2)
-- | Device- and dtype-specific validity check for 'dot'.
--
-- * On @'( 'D.CPU, 0)@ the dtype must be neither Bool nor Half.
-- * On any CUDA device the dtype must be floating point.
-- * Every other device is reported as unsupported.
type family DotDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  DotDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsNotBool '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  DotDTypeIsValid '( 'D.CUDA, deviceIndex) dtype = DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype
  DotDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype
-- | Inner product of two equal-length vectors, via ATen's @dot@.
--
-- The result is a rank-0 tensor. The dtype/device combination must satisfy
-- 'DotDTypeIsValid'.
dot ::
  forall size dtype device.
  DotDTypeIsValid device dtype =>
  Tensor device dtype '[size] ->
  Tensor device dtype '[size] ->
  Tensor device dtype '[]
dot input other =
  -- unsafePerformIO: the ATen call is assumed referentially transparent.
  unsafePerformIO $ ATen.cast2 ATen.Managed.dot_tt input other
-- | Reflect a type-level @Maybe Nat@ to a term-level @Maybe Integer@.
--
-- Use a type application to select the instance, e.g.
-- @maybeNatVal \@('Just 3)@.
class KnownMaybeNat (n :: Maybe Nat) where
  maybeNatVal :: Maybe Integer

-- | A present value reflects to 'Just' its natural number.
instance (KnownNat n) => KnownMaybeNat ('Just n) where
  maybeNatVal = Just (natVal (Proxy @n))

-- | An absent value reflects to 'Nothing'.
instance KnownMaybeNat 'Nothing where
  maybeNatVal = Nothing
-- | Constraint on the optional padding index used by 'embedding'.
--
-- A present index @n@ must satisfy @n < numEmbeds@ (written as
-- @n + 1 <= numEmbeds@); an absent index imposes no constraint.
type family PaddingIdxCheck (idx :: Maybe Nat) (numEmbeds :: Nat) :: Constraint where
  PaddingIdxCheck ('Just n) numEmbeds = n + 1 <= numEmbeds
  PaddingIdxCheck 'Nothing _ = ()
-- | Look up rows of an embedding table.
--
-- Each @Int64@ entry of @indices@ selects a row of @weights@. The result has
-- the shape of @indices@ with @embedDim@ appended as a trailing dimension.
--
-- The type-level @paddingIdx@ is passed to ATen's @embedding@ as an 'Int';
-- 'Nothing' is encoded as @-1@. @scaleGradByFreq@ and @sparse@ are forwarded
-- unchanged.
embedding ::
  forall (paddingIdx :: Maybe Nat) numEmbeds embedDim shape dtype device.
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds
  ) =>
  -- | scaleGradByFreq
  Bool ->
  -- | sparse
  Bool ->
  -- | weights
  Tensor device dtype '[numEmbeds, embedDim] ->
  -- | indices
  Tensor device 'D.Int64 shape ->
  -- | output
  Tensor device dtype (Reverse (embedDim ': Reverse shape))
embedding scaleGradByFreq sparse weights indices =
  unsafePerformIO $
    ATen.cast5
      ATen.Managed.embedding_ttlbb
      weights
      indices
      paddingIdx
      scaleGradByFreq
      sparse
  where
    -- -1 tells ATen that no padding index is in use.
    paddingIdx :: Int
    paddingIdx = maybe (-1) fromIntegral (maybeNatVal @paddingIdx)
-- | Allocate a new tensor with the same device, dtype and shape as @input@,
-- via ATen's @empty_like@.
--
-- Its contents are presumably uninitialised (TODO confirm). Unlike most
-- wrappers in this module, it stays in 'IO' instead of using
-- 'unsafePerformIO'.
emptyLike ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | freshly allocated output
  IO (Tensor device dtype shape)
emptyLike input = ATen.cast1 ATen.Managed.empty_like_t input
-- | Complementary error function, applied element-wise (ATen @erfc@).
erfc ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
erfc input = unsafePerformIO $ ATen.cast1 ATen.Managed.erfc_t input
-- | @exp x - 1@, applied element-wise (ATen @expm1@).
expm1 ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
expm1 input = unsafePerformIO $ ATen.cast1 ATen.Managed.expm1_t input
-- | Expand @input@ to the type-level target shape @shape'@.
--
-- The constraint @shape' ~ Broadcast shape shape'@ requires that broadcasting
-- @shape@ against @shape'@ yields @shape'@ itself. The runtime size list is
-- taken from 'shapeVal' of @shape'@.
expand ::
  forall shape' shape dtype device.
  ( KnownShape shape',
    shape' ~ Broadcast shape shape'
  ) =>
  -- | flag forwarded unchanged as the last argument of ATen's @expand@
  -- (presumably its @implicit@ flag — TODO confirm)
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
expand someBool input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.tensor_expand_lb input (shapeVal @shape') someBool
-- | Flatten every dimension of @input@ into a single one of size
-- @Product shape@, by calling ATen's @flatten@ from dimension 0 to the last
-- dimension (-1).
flattenAll ::
  forall shape dtype device.
  KnownShape shape =>
  -- | input
  Tensor device dtype shape ->
  -- | rank-1 output
  Tensor device dtype '[Product shape]
flattenAll input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.flatten_tll input (0 :: Int) (-1 :: Int)
-- | Fractional part of each element (ATen @frac@).
frac ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
frac input = unsafePerformIO $ ATen.cast1 ATen.Managed.frac_t input
-- | A tensor with the same device, dtype and shape as @input@, with every
-- element set to @fillValue@ (ATen @full_like@).
fullLike ::
  forall shape dtype device.
  -- | fillValue
  Float ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
fullLike fillValue input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.full_like_ts input fillValue
-- | Element-wise closeness test between @input@ and @other@ (ATen @isclose@).
--
-- @rtol@ and @atol@ are the relative and absolute tolerances. @equalNaN@
-- decides whether two NaNs are treated as close.
isclose ::
  forall shape dtype device.
  -- | rtol
  Double ->
  -- | atol
  Double ->
  -- | equalNaN
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | other
  Tensor device dtype shape ->
  -- | element-wise result
  Tensor device 'D.Bool shape
isclose rtol atol equalNaN input other =
  unsafePerformIO $
    ATen.cast5 ATen.Managed.isclose_ttddb input other rtol atol equalNaN
-- | Element-wise NaN test (ATen @isnan@); the result is a boolean tensor of
-- the same shape.
isNaN ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | element-wise result
  Tensor device 'D.Bool shape
isNaN input = unsafePerformIO $ ATen.cast1 ATen.Managed.isnan_t input
-- | Result of ATen's @is_distributed@ for @input@, as a Haskell 'Bool'.
isDistributed ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  Bool
isDistributed input =
  unsafePerformIO $ ATen.cast1 ATen.Managed.is_distributed_t input
-- | Whether ATen reports @input@'s dtype as floating point
-- (ATen @is_floating_point@).
isFloatingPoint ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  Bool
isFloatingPoint input =
  unsafePerformIO $ ATen.cast1 ATen.Managed.is_floating_point_t input
-- | Whether ATen reports @input@'s dtype as complex (ATen @is_complex@).
isComplex ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  Bool
isComplex input = unsafePerformIO $ ATen.cast1 ATen.Managed.is_complex_t input
-- | Whether the single element of a one-element tensor is non-zero
-- (ATen @is_nonzero@).
--
-- @Numel shape ~ 1@ rules out tensors with any other number of elements at
-- compile time.
isNonZero ::
  forall shape dtype device.
  (Numel shape ~ 1) =>
  -- | input
  Tensor device dtype shape ->
  Bool
isNonZero input = unsafePerformIO $ ATen.cast1 ATen.Managed.is_nonzero_t input
-- | Runtime check (ATen @is_same_size@) of whether @input@ and @other@ have
-- the same size. The two type-level shapes may differ.
isSameSize ::
  forall shape shape' dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | other
  Tensor device dtype shape' ->
  Bool
isSameSize input other =
  unsafePerformIO $ ATen.cast2 ATen.Managed.is_same_size_tt input other
-- | Whether ATen reports @input@'s dtype as signed (ATen @is_signed@).
isSigned ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  Bool
isSigned input = unsafePerformIO $ ATen.cast1 ATen.Managed.is_signed_t input
-- | Layer normalisation over the trailing dimensions @normalizedShape@ of
-- @input@ (ATen @layer_norm@), with affine @weight@ and @bias@ and the
-- numerical-stability term @eps@.
--
-- 'IsSuffixOf' makes sure @normalizedShape@ is a suffix of @shape@. The
-- cuDNN flag passed to ATen is set only when 'cudnnIsAcceptable' holds for
-- @weight@, @bias@ and @input@.
layerNorm ::
  forall normalizedShape shape dtype device.
  ( KnownShape normalizedShape,
    IsSuffixOf normalizedShape shape
  ) =>
  -- | weight
  Tensor device dtype normalizedShape ->
  -- | bias
  Tensor device dtype normalizedShape ->
  -- | eps
  Double ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
layerNorm weight bias eps input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.layer_norm_tlttdb
      input
      (shapeVal @normalizedShape)
      weight
      bias
      eps
      cudnnEnabled
  where
    cudnnEnabled =
      cudnnIsAcceptable weight
        && cudnnIsAcceptable bias
        && cudnnIsAcceptable input
-- | Affine map of a batch of feature vectors (ATen @linear@).
--
-- @weight@ is stored as @[outputFeatures, inputFeatures]@. Each row of
-- @input@ is multiplied by its transpose, and then @bias@ is added.
linear ::
  forall batchSize inputFeatures outputFeatures dtype device.
  -- | weight
  Tensor device dtype '[outputFeatures, inputFeatures] ->
  -- | bias
  Tensor device dtype '[outputFeatures] ->
  -- | input
  Tensor device dtype '[batchSize, inputFeatures] ->
  -- | output
  Tensor device dtype '[batchSize, outputFeatures]
linear weight bias input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.linear_ttt input weight bias
-- | Like 'linear', but for an input of any shape whose result type is given
-- by @MatMul shape '[inputFeatures, outputFeatures]@.
--
-- NOTE(review): the bias shape @'[outputFeatures]@ does not appear in the
-- 'Broadcast' constraint (it broadcasts @shape''@ with itself); confirm this
-- is intended.
linear' ::
  forall (inputFeatures :: Nat) (outputFeatures :: Nat) (shape :: [Nat]) (shape' :: [Nat]) dtype device (shape'' :: [Nat]).
  ( shape'' ~ MatMul shape '[inputFeatures, outputFeatures],
    shape' ~ Broadcast shape'' shape''
  ) =>
  -- | weight
  Tensor device dtype '[outputFeatures, inputFeatures] ->
  -- | bias
  Tensor device dtype '[outputFeatures] ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
linear' weight bias input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.linear_ttt input weight bias
-- | Same typing as 'linear', but dispatched to ATen's @mkldnn_linear@.
--
-- Presumably this requires MKL-DNN-backed tensors (TODO confirm).
mkldnnLinear ::
  forall batchSize inputFeatures outputFeatures dtype device.
  -- | weight
  Tensor device dtype '[outputFeatures, inputFeatures] ->
  -- | bias
  Tensor device dtype '[outputFeatures] ->
  -- | input
  Tensor device dtype '[batchSize, inputFeatures] ->
  -- | output
  Tensor device dtype '[batchSize, outputFeatures]
mkldnnLinear weight bias input =
  unsafePerformIO $ ATen.cast3 ATen.Managed.mkldnn_linear_ttt input weight bias
-- | Natural logarithm, applied element-wise (ATen @log@).
--
-- NOTE(review): unlike 'erfc' and 'expm1', this has no dtype validation
-- constraint.
log ::
  forall shape dtype device.
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
log input = unsafePerformIO $ ATen.cast1 ATen.Managed.log_t input
-- | Log-determinant of @input@ (ATen @logdet@); the result shape is
-- @Det shape@.
logDet ::
  forall shape' shape dtype device.
  (shape' ~ Det shape) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
logDet input = unsafePerformIO $ ATen.cast1 ATen.Managed.logdet_t input
-- | Log of the sum of exponentials of @input@ along dimension @dim@
-- (ATen @logsumexp@).
--
-- @keepOrDropDim@ decides whether the reduced dimension is kept or dropped;
-- the result shape is computed by 'ConditionalDropDimension'.
logSumExp ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    Reifies dtype D.DType,
    DTypeIsFloatingPoint device dtype,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
logSumExp input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.logsumexp_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | Raise @input@ to the integer matrix power @n@ (ATen @matrix_power@).
--
-- The result shape is @Square shape@.
matrixPower ::
  forall shape' shape dtype device.
  (shape' ~ Square shape) =>
  -- | power
  Int ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
matrixPower n input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.matrix_power_tl input n
-- | Maximum values of @input@ along dimension @dim@ (ATen @max_values@).
--
-- @keepOrDropDim@ decides whether the reduced dimension is kept or dropped;
-- the result shape is computed by 'ConditionalDropDimension'.
maxValues ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
maxValues input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.max_values_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | Minimum values of @input@ along dimension @dim@ (ATen @min_values@).
--
-- @keepOrDropDim@ decides whether the reduced dimension is kept or dropped;
-- the result shape is computed by 'ConditionalDropDimension'.
minValues ::
  forall dim keepOrDropDim shape' shape dtype device.
  ( KnownNat dim,
    KnownKeepOrDropDim keepOrDropDim,
    shape' ~ ConditionalDropDimension shape dim keepOrDropDim
  ) =>
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape'
minValues input =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.min_values_tlb
      input
      (natValI @dim)
      (keepOrDropDimVal @keepOrDropDim)
-- | 1-D max pooling over the last dimension (ATen @max_pool1d@).
--
-- Kernel size, stride and padding come from type-level naturals.
-- 'ConvSideCheck' ties @outputSize@ to @inputSize@ and those parameters.
-- Dilation is fixed to 1 and ceil mode is off.
maxPool1d ::
  forall kernelSize stride padding channelSize inputSize batchSize outputSize dtype device.
  ( All
      KnownNat
      '[ kernelSize,
         stride,
         padding,
         channelSize,
         inputSize,
         batchSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize]
maxPool1d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.max_pool1d_tllllb
      input
      (natValI @kernelSize)
      (natValI @stride)
      (natValI @padding)
      dilation
      ceilMode
  where
    -- Positional ATen arguments after padding: dilation, then ceil_mode.
    dilation = 1 :: Int
    ceilMode = False
-- | 2-D max pooling over the last two dimensions (ATen @max_pool2d@).
--
-- Kernel size, stride and padding are type-level pairs. Their first
-- components apply to @inputSize0@ and their second to @inputSize1@.
-- 'ConvSideCheck' ties each output size to the matching input size and
-- parameters. Dilation is fixed to @[1, 1]@ and ceil mode is off.
maxPool2d ::
  forall kernelSize stride padding channelSize inputSize0 inputSize1 batchSize outputSize0 outputSize1 dtype device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst kernelSize,
         Torch.Typed.Auxiliary.Snd kernelSize,
         Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         channelSize,
         inputSize0,
         inputSize1,
         batchSize
       ],
    ConvSideCheck inputSize0 (Torch.Typed.Auxiliary.Fst kernelSize) (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 (Torch.Typed.Auxiliary.Snd kernelSize) (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1]
maxPool2d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.max_pool2d_tllllb
      input
      kernel
      strides
      paddings
      dilation
      ceilMode
  where
    kernel, strides, paddings, dilation :: [Int]
    kernel =
      [ natValI @(Torch.Typed.Auxiliary.Fst kernelSize),
        natValI @(Torch.Typed.Auxiliary.Snd kernelSize)
      ]
    strides =
      [ natValI @(Torch.Typed.Auxiliary.Fst stride),
        natValI @(Torch.Typed.Auxiliary.Snd stride)
      ]
    paddings =
      [ natValI @(Torch.Typed.Auxiliary.Fst padding),
        natValI @(Torch.Typed.Auxiliary.Snd padding)
      ]
    -- Positional ATen arguments after padding: dilation, then ceil_mode.
    dilation = [1, 1]
    ceilMode = False
-- | Same typing as 'maxPool2d', but dispatched to ATen's
-- @mkldnn_max_pool2d@.
--
-- Presumably this requires MKL-DNN-backed tensors (TODO confirm). Dilation
-- is fixed to @[1, 1]@ and ceil mode is off.
mkldnnMaxPool2d ::
  forall kernelSize stride padding channelSize inputSize0 inputSize1 batchSize outputSize0 outputSize1 dtype device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst kernelSize,
         Torch.Typed.Auxiliary.Snd kernelSize,
         Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         channelSize,
         inputSize0,
         inputSize1,
         batchSize
       ],
    ConvSideCheck inputSize0 (Torch.Typed.Auxiliary.Fst kernelSize) (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 (Torch.Typed.Auxiliary.Snd kernelSize) (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1]
mkldnnMaxPool2d input =
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.mkldnn_max_pool2d_tllllb
      input
      kernel
      strides
      paddings
      dilation
      ceilMode
  where
    kernel, strides, paddings, dilation :: [Int]
    kernel =
      [ natValI @(Torch.Typed.Auxiliary.Fst kernelSize),
        natValI @(Torch.Typed.Auxiliary.Snd kernelSize)
      ]
    strides =
      [ natValI @(Torch.Typed.Auxiliary.Fst stride),
        natValI @(Torch.Typed.Auxiliary.Snd stride)
      ]
    paddings =
      [ natValI @(Torch.Typed.Auxiliary.Fst padding),
        natValI @(Torch.Typed.Auxiliary.Snd padding)
      ]
    -- Positional ATen arguments after padding: dilation, then ceil_mode.
    dilation = [1, 1]
    ceilMode = False
-- | Same typing as 'maxPool2d', but dispatched to ATen's
-- @quantized_max_pool2d@.
--
-- Presumably this expects a quantized input tensor (TODO confirm). Dilation
-- is fixed to @[1, 1]@. This ATen entry point takes no ceil-mode argument.
quantizedMaxPool2d ::
  forall kernelSize stride padding channelSize inputSize0 inputSize1 batchSize outputSize0 outputSize1 dtype device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst kernelSize,
         Torch.Typed.Auxiliary.Snd kernelSize,
         Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         channelSize,
         inputSize0,
         inputSize1,
         batchSize
       ],
    ConvSideCheck inputSize0 (Torch.Typed.Auxiliary.Fst kernelSize) (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 (Torch.Typed.Auxiliary.Snd kernelSize) (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  -- | input
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  -- | output
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1]
quantizedMaxPool2d input =
  unsafePerformIO $
    ATen.cast5
      ATen.Managed.quantized_max_pool2d_tllll
      input
      kernel
      strides
      paddings
      dilation
  where
    kernel, strides, paddings, dilation :: [Int]
    kernel =
      [ natValI @(Torch.Typed.Auxiliary.Fst kernelSize),
        natValI @(Torch.Typed.Auxiliary.Snd kernelSize)
      ]
    strides =
      [ natValI @(Torch.Typed.Auxiliary.Fst stride),
        natValI @(Torch.Typed.Auxiliary.Snd stride)
      ]
    paddings =
      [ natValI @(Torch.Typed.Auxiliary.Fst padding),
        natValI @(Torch.Typed.Auxiliary.Snd padding)
      ]
    dilation = [1, 1]
-- | 3-D max pooling over a @[batchSize, channelSize, d0, d1, d2]@ tensor.
--
-- Kernel size, stride and padding are type-level triples whose components
-- are read with 'Fst3', 'Snd3' and 'Trd3'. Each output spatial size is tied
-- to the corresponding input size by 'ConvSideCheck'.
--
-- The remaining ATen arguments are fixed:
--
-- * the fourth list is @[1, 1, 1]@ (dilation, per ATen's @max_pool3d@
--   argument order);
-- * the trailing flag is 'False' (@ceil_mode@).
maxPool3d ::
  forall
    kernelSize
    stride
    padding
    channelSize
    inputSize0
    inputSize1
    inputSize2
    batchSize
    outputSize0
    outputSize1
    outputSize2
    dtype
    device.
  ( All
      KnownNat
      '[ Fst3 kernelSize,
         Snd3 kernelSize,
         Trd3 kernelSize,
         Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         channelSize,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 (Fst3 kernelSize) (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 (Snd3 kernelSize) (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 (Trd3 kernelSize) (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1, inputSize2] ->
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1, outputSize2]
maxPool3d input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $
    ATen.cast6
      ATen.Managed.max_pool3d_tllllb
      input
      ([natValI @(Fst3 kernelSize), natValI @(Snd3 kernelSize), natValI @(Trd3 kernelSize)] :: [Int])
      ([natValI @(Fst3 stride), natValI @(Snd3 stride), natValI @(Trd3 stride)] :: [Int])
      ([natValI @(Fst3 padding), natValI @(Snd3 padding), natValI @(Trd3 padding)] :: [Int])
      ([1, 1, 1] :: [Int])
      False
-- | Overwrite the elements of @input@ selected by @mask@ with @value@
-- (ATen @masked_fill@).
--
-- The mask and the input are broadcast against each other, so the result
-- shape is @'Broadcast' shape shape'@.
maskedFill ::
  forall a shape shape' shape'' dtype device.
  (D.Scalar a, shape'' ~ Broadcast shape shape') =>
  Tensor device 'D.Bool shape' ->
  a ->
  Tensor device dtype shape ->
  Tensor device dtype shape''
maskedFill mask value input =
  -- ATen takes the input first, then the mask and the fill value.
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $ ATen.cast3 ATen.Managed.masked_fill_tts input mask value
-- | Matrix product (ATen @mm@): an @[n, k]@ tensor times a @[k, m]@ tensor
-- gives an @[n, m]@ tensor.
mm ::
  forall n k m dtype device.
  Tensor device dtype '[n, k] ->
  Tensor device dtype '[k, m] ->
  Tensor device dtype '[n, m]
mm a b =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $ ATen.cast2 ATen.Managed.mm_tt a b
-- | Matrix-vector product (ATen @mv@): an @[n, m]@ matrix times an @[m]@
-- vector gives an @[n]@ vector.
mv ::
  forall n m dtype device.
  Tensor device dtype '[n, m] ->
  Tensor device dtype '[m] ->
  Tensor device dtype '[n]
mv input vec =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $ ATen.cast2 ATen.Managed.mv_tt input vec
-- | Pick the narrowed shape computed by 'Narrow'', or report an out-of-bounds
-- dimension via 'DimOutOfBound'.
--
-- @mbCurrent@ is the size at the requested dimension ('Nothing' when the
-- dimension does not exist). @mbUpdated@ is the narrowed shape, if one was
-- computed. If either is 'Nothing', the result is @DimOutOfBound shape dim@.
type family NarrowCheck (mbCurrent :: Maybe Nat) (mbUpdated :: Maybe [Nat]) (shape :: [Nat]) (dim :: Nat) (start :: Nat) (length :: Nat) :: [Nat] where
  NarrowCheck 'Nothing _ sh d _ _ = DimOutOfBound sh d
  NarrowCheck ('Just _) 'Nothing sh d _ _ = DimOutOfBound sh d
  NarrowCheck _ ('Just narrowed) _ _ _ _ = narrowed
-- | Compute the shape obtained by narrowing dimension @dim@ of @shape@ to
-- @length@ elements starting at index @start@.
--
-- @current@ is the size of @shape@ at @dim@ ('Nothing' if that dimension
-- does not exist). A custom 'TypeError' is raised in two cases:
--
-- * the segment end @start + length@ exceeds @current@;
-- * the dimension is missing.
--
-- Otherwise the size at @dim@ is replaced by @length@ through 'ReplaceDim'.
type family Narrow' (dim :: Nat) (shape :: [Nat]) (current :: Maybe Nat) (start :: Nat) (length :: Nat) :: Maybe [Nat] where
  Narrow' d sh (Just c) s l =
    If
      ((s + l) <=? c)
      (ReplaceDim d sh l)
      ( TypeError
          ( Text "The end of the requested narrow segment "
              :<>: ShowType (s + l)
              :<>: Text " would be larger than current size "
              :<>: ShowType c
              :<>: Text " at dimension "
              :<>: ShowType d
          )
      )
  Narrow' d sh Nothing s l =
    TypeError
      ( Text "Requested narrow dimension "
          :<>: ShowType d
          -- Fixed typo in the user-facing message ("doesnt" -> "doesn't").
          :<>: Text " doesn't exist in "
          :<>: ShowType sh
      )
-- | Result shape of 'narrow'.
--
-- Looks up the current size at @dim@ with 'ExtractDim', lets 'Narrow''
-- compute the narrowed shape, and lets 'NarrowCheck' turn a missing result
-- into a compile-time error.
type family Narrow (shape :: [Nat]) (dim :: Nat) (start :: Nat) (length :: Nat) :: [Nat] where
  Narrow sh d s l =
    NarrowCheck
      (ExtractDim d sh)
      (Narrow' d sh (ExtractDim d sh) s l)
      sh
      d
      s
      l
-- | Take @length@ consecutive elements of dimension @dim@, starting at index
-- @start@ (ATen @narrow@).
--
-- The result shape is computed by 'Narrow'. It rejects at compile time a
-- missing dimension or a segment that runs past the end of the dimension.
--
-- @mbSize@ and @mbNewShape@ are not used by the body. They stay in the
-- @forall@ because removing them would shift the positions of @dtype@ and
-- @device@ for callers using explicit type applications.
narrow ::
  forall dim start length shape mbSize mbNewShape dtype device.
  ( All KnownNat '[dim, start, length],
    All KnownNat shape
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype (Narrow shape dim start length)
narrow input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $
    ATen.cast4 ATen.Managed.narrow_tlll input dimIndex startIndex segmentLength
  where
    -- Reflect the type-level arguments to the Int values ATen expects.
    dimIndex, startIndex, segmentLength :: Int
    dimIndex = natValI @dim
    startIndex = natValI @start
    segmentLength = natValI @length
-- | A tensor of ones with the same shape, dtype and device as the argument
-- (ATen @ones_like@).
onesLike ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
onesLike input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.ones_like_t input
-- | A new random tensor with the same shape, dtype and device as the
-- argument (ATen @rand_like@; presumably uniform on [0, 1) as in PyTorch —
-- confirm against the ATen docs).
--
-- Unlike the pure wrappers in this module, this stays in 'IO', so each call
-- can produce different values.
randLike ::
  forall shape dtype device.
  Tensor device dtype shape ->
  IO (Tensor device dtype shape)
randLike = ATen.cast1 ATen.Managed.rand_like_t
-- | A new random tensor with the same shape, dtype and device as the
-- argument (ATen @randn_like@; presumably standard-normal samples as in
-- PyTorch — confirm against the ATen docs).
--
-- Unlike the pure wrappers in this module, this stays in 'IO', so each call
-- can produce different values.
randnLike ::
  forall shape dtype device.
  Tensor device dtype shape ->
  IO (Tensor device dtype shape)
randnLike = ATen.cast1 ATen.Managed.randn_like_t
-- | Element-wise reciprocal, @1 / x@ (ATen @reciprocal@).
reciprocal ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
reciprocal input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.reciprocal_t input
-- | Element-wise negation (ATen @neg@).
neg ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
neg input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.neg_t input
-- | Element-wise rounding to the nearest integer value (ATen @round@).
round ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
round input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.round_t input
-- | Parametric ReLU (ATen @prelu@) with a single 0-dimensional slope tensor
-- @weight@.
--
-- Note the argument order: @weight@ comes first here, but ATen receives
-- @input@ first.
prelu ::
  forall shape dtype device.
  Tensor device dtype '[] ->
  Tensor device dtype shape ->
  Tensor device dtype shape
prelu weight input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $ ATen.cast2 ATen.Managed.prelu_tt input weight
-- | Dtype constraint for 'gelu'.
--
-- On the CPU device with index 0, and on any CUDA device, the dtype must be
-- floating point and not Half. Every other device yields
-- 'UnsupportedDTypeForDevice'.
type family GeluDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  GeluDTypeIsValid '( 'D.CPU, 0) dt =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  GeluDTypeIsValid '( 'D.CUDA, idx) dt =
    ( DTypeIsFloatingPoint '( 'D.CUDA, idx) dt,
      DTypeIsNotHalf '( 'D.CUDA, idx) dt
    )
  GeluDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt
-- | GELU activation (ATen @gelu@).
--
-- 'GeluDTypeIsValid' restricts the dtype to non-Half floating-point types on
-- supported devices.
gelu ::
  forall shape dtype device.
  (GeluDTypeIsValid device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
gelu input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.gelu_t input
-- | Element-wise reciprocal square root, @1 / sqrt x@ (ATen @rsqrt@).
-- Restricted to the dtypes accepted by 'StandardFloatingPointDTypeValidation'.
rsqrt ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
rsqrt input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.rsqrt_t input
-- | CELU activation with parameter @alpha@ (ATen @celu@).
--
-- @alpha@ is handed to ATen as a scalar. ATen receives @input@ first.
celu ::
  forall shape dtype device.
  Float ->
  Tensor device dtype shape ->
  Tensor device dtype shape
celu alpha input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $ ATen.cast2 ATen.Managed.celu_ts input alpha
-- | Walk a list of tensor types for 'Stack', counting the elements and
-- requiring every one to be exactly the same @Tensor device dtype shape@.
--
-- @count@ is the number of tensors seen so far; 'Stack' starts it at 1. The
-- result is 'Nothing' in three cases:
--
-- * the list is empty;
-- * two neighbouring elements differ (or an element is not a 'Tensor');
-- * 'ComputeStackShape' cannot place the new dimension.
type family StackImpl (dim :: Nat) (tensors :: [a]) (count :: Nat) :: Maybe ([Nat], D.DType, (D.DeviceType, Nat)) where
  StackImpl _ '[] _ = 'Nothing
  StackImpl d (Tensor dev dt sh ': '[]) n =
    MaybeTriple (ComputeStackShape sh d n) ('Just dt) ('Just dev)
  -- Non-linear pattern: the first two elements must be the identical type.
  StackImpl d (Tensor dev dt sh ': Tensor dev dt sh ': rest) n =
    StackImpl d (Tensor dev dt sh ': rest) (n + 1)
  StackImpl _ _ _ = 'Nothing
-- | Type-level pairing of two 'Maybe's. Yields 'Nothing' if either input is
-- 'Nothing'.
type family MaybePair (a' :: Maybe a) (b' :: Maybe b) :: Maybe (a, b) where
  MaybePair 'Nothing _ = 'Nothing
  MaybePair _ 'Nothing = 'Nothing
  MaybePair ('Just x) ('Just y) = 'Just '(x, y)
-- | Type-level tripling of three 'Maybe's. Yields 'Nothing' if any input is
-- 'Nothing'.
type family MaybeTriple (a' :: Maybe a) (b' :: Maybe b) (c' :: Maybe c) :: Maybe (a, b, c) where
  MaybeTriple 'Nothing _ _ = 'Nothing
  MaybeTriple _ 'Nothing _ = 'Nothing
  MaybeTriple _ _ 'Nothing = 'Nothing
  MaybeTriple ('Just x) ('Just y) ('Just z) = 'Just '(x, y, z)
-- | Shape after stacking @count@ tensors of shape @shape@ along @dim@.
--
-- @count@ is placed at position @dim@; this assumes 'AppendToMaybe' conses
-- onto a 'Just' list. @dim@ equal to the rank appends at the end. The result
-- is 'Nothing' when @count@ is 0 or @dim@ exceeds the rank.
type family ComputeStackShape (shape :: [Nat]) (dim :: Nat) (count :: Nat) :: Maybe [Nat] where
  ComputeStackShape _ _ 0 = 'Nothing
  ComputeStackShape sh 0 n = 'Just (n ': sh)
  ComputeStackShape (s ': rest) d n = AppendToMaybe s (ComputeStackShape rest (d - 1) n)
  ComputeStackShape '[] _ _ = 'Nothing
-- | Unwrap the result of 'StackImpl'. Fails compilation with a custom
-- 'TypeError' when stacking is not possible.
type family StackCheck (res :: Maybe ([Nat], D.DType, (D.DeviceType, Nat))) :: ([Nat], D.DType, (D.DeviceType, Nat)) where
  StackCheck 'Nothing = TypeError (Text "Stacking impossible.")
  StackCheck ('Just '(sh, dt, dev)) = '(sh, dt, dev)

-- | The @'(shape, dtype, device)@ obtained by stacking @tensors@ along @dim@.
-- The element count starts at 1 for the first tensor.
type Stack dim tensors = StackCheck (StackImpl dim tensors 1)
-- | Stack a heterogeneous list of identically-typed tensors along a new
-- dimension @dim@ (ATen @stack@).
--
-- 'Stack' checks at compile time that every tensor has the same device,
-- dtype and shape. It computes the result shape by inserting the tensor count
-- at position @dim@.
stack ::
  forall dim shape dtype device tensors.
  ( KnownNat dim,
    '(shape, dtype, device) ~ Stack dim tensors,
    ATen.Castable (HList tensors) [D.ATenTensor]
  ) =>
  HList tensors ->
  Tensor device dtype shape
stack tensors =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $
    ATen.cast2 ATen.Managed.stack_ll tensors (natValI @dim :: Int)
-- | Stack a length-@n@ sized vector of same-typed tensors along a new
-- dimension @dim@ (ATen @stack@). The result shape is @'Insert' dim n shape@.
vecStack ::
  forall dim n shape dtype device.
  (KnownNat dim, KnownNat n) =>
  Vector n (Tensor device dtype shape) ->
  Tensor device dtype (Insert dim n shape)
vecStack tensors =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $
    ATen.cast2 ATen.Managed.stack_ll tensors (natValI @dim :: Int)
-- | Transpose via ATen @t@.
--
-- NOTE(review): per the PyTorch @torch.t@ docs, @t@ swaps the two dimensions
-- of a 2-D tensor and returns 0-D and 1-D tensors unchanged. This signature,
-- however, claims the shape is preserved. That is only accurate for 0-D, 1-D
-- and square 2-D inputs. Correcting the type would change the public
-- signature, so confirm before relying on it for other shapes.
t ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
t input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.t_t input
-- | Element-wise tangent (ATen @tan@). Restricted to the dtypes accepted by
-- 'StandardFloatingPointDTypeValidation'.
tan ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
tan input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.tan_t input
-- | Element-wise truncation towards zero (ATen @trunc@). Restricted to the
-- dtypes accepted by 'StandardFloatingPointDTypeValidation'.
trunc ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
trunc input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.trunc_t input
-- | Insert a @1@ at position @dim@ of @shape@. This assumes 'AppendToMaybe'
-- conses onto a 'Just' list.
--
-- @dim@ equal to the rank appends at the end. A larger @dim@ gives
-- 'Nothing'.
type family UnsqueezeImpl (shape :: [a]) (dim :: Nat) :: Maybe [a] where
  UnsqueezeImpl sh 0 = 'Just (1 ': sh)
  UnsqueezeImpl (s ': rest) d = AppendToMaybe s (UnsqueezeImpl rest (d - 1))
  UnsqueezeImpl '[] _ = 'Nothing
-- | Unwrap the result of 'UnsqueezeImpl'. Reports a custom 'TypeError'
-- naming the requested dimension and the tensor's rank when it is
-- 'Nothing'.
type family UnsqueezeCheck (shape :: [a]) (dim :: Nat) (result :: Maybe [a]) :: [a] where
  UnsqueezeCheck sh d 'Nothing =
    TypeError
      ( Text "Cannot unsqueeze the tensor since the specified dimension "
          :<>: ShowType d
          :<>: Text " is too large (the tensor is only "
          :<>: ShowType (ListLength sh)
          :<>: Text "D)"
      )
  UnsqueezeCheck _ _ ('Just unsqueezed) = unsqueezed

-- | Shape after inserting a size-1 dimension at @dim@. Errors at compile
-- time if @dim@ exceeds the rank.
type Unsqueeze shape dim = UnsqueezeCheck shape dim (UnsqueezeImpl shape dim)
-- | Insert a size-1 dimension at position @dim@ (ATen @unsqueeze@).
--
-- The result shape is @'Unsqueeze' shape dim@, which rejects a @dim@ larger
-- than the rank at compile time.
unsqueeze ::
  forall dim shape shape' dtype device.
  (KnownNat dim, shape' ~ Unsqueeze shape dim) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
unsqueeze input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast2 ATen.Managed.unsqueeze_tl input (natValI @dim)
-- | Remove every dimension of size 1 from a shape.
type family SqueezeAll (shape :: [Nat]) :: [Nat] where
  SqueezeAll '[] = '[]
  SqueezeAll (1 ': rest) = SqueezeAll rest
  SqueezeAll (d ': rest) = d ': SqueezeAll rest
-- | Drop every size-1 dimension (ATen @squeeze@ without a dimension).
-- The result shape is @'SqueezeAll' shape@.
squeezeAll ::
  forall shape shape' dtype device.
  (shape' ~ SqueezeAll shape) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
squeezeAll input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.squeeze_t input
-- | Remove dimension @dim@ from @shape@ when its size is 1. This assumes
-- 'AppendToMaybe' conses onto a 'Just' list.
--
-- The result is 'Nothing' if that dimension is not 1 or does not exist.
type family SqueezeDimImpl (shape :: [a]) (dim :: Nat) :: Maybe [a] where
  SqueezeDimImpl (1 ': rest) 0 = 'Just rest
  SqueezeDimImpl _ 0 = 'Nothing
  SqueezeDimImpl (s ': rest) d = AppendToMaybe s (SqueezeDimImpl rest (d - 1))
  SqueezeDimImpl _ _ = 'Nothing
-- | Unwrap the result of 'SqueezeDimImpl'. Reports a custom 'TypeError'
-- naming the dimension when it is 'Nothing'.
type family SqueezeDimCheck (shape :: [a]) (dim :: Nat) (result :: Maybe [a]) :: [a] where
  SqueezeDimCheck _ d 'Nothing =
    TypeError
      ( Text "The tensor cannot be squeezed at the specified dimension "
          :<>: ShowType d
      )
  SqueezeDimCheck _ _ ('Just squeezed) = squeezed

-- | Shape after removing the size-1 dimension @dim@.
type SqueezeDim shape dim = SqueezeDimCheck shape dim (SqueezeDimImpl shape dim)
-- | Drop dimension @dim@, which must have size 1 (ATen @squeeze@ with a
-- dimension).
--
-- The result shape is @'SqueezeDim' shape dim@, which fails to compile if
-- the dimension is missing or not of size 1.
squeezeDim ::
  forall dim shape shape' dtype device.
  (KnownNat dim, shape' ~ SqueezeDim shape dim) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
squeezeDim input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast2 ATen.Managed.squeeze_tl input (natValI @dim)
-- | A tensor of zeros with the same shape, dtype and device as the argument
-- (ATen @zeros_like@).
zerosLike ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
zerosLike input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its argument untouched.
  unsafePerformIO $ ATen.cast1 ATen.Managed.zeros_like_t input
-- | Copy a tensor (ATen @clone@).
--
-- Unlike most functions in this module, the result is returned in 'IO'
-- rather than through 'unsafePerformIO'.
clone ::
  forall shape dtype device.
  Tensor device dtype shape ->
  IO (Tensor device dtype shape)
clone input = ATen.cast1 ATen.Managed.clone_t input
-- | @beta * input + alpha * (mat1 `mm` mat2)@ (ATen @addmm@, formula per the
-- PyTorch docs).
--
-- @input@ is broadcast against the @[n, m]@ product, so the result shape is
-- @'Broadcast' shape '[n, m]@.
--
-- Argument order differs from ATen: this function takes @beta@ and @alpha@
-- first. ATen receives @input, mat1, mat2, beta, alpha@.
addmm ::
  forall shape' shape n k m dtype device.
  ( All KnownNat '[n, k, m],
    shape' ~ Broadcast shape '[n, m]
  ) =>
  Float ->
  Float ->
  Tensor device dtype '[n, k] ->
  Tensor device dtype '[k, m] ->
  Tensor device dtype shape ->
  Tensor device dtype shape'
addmm beta alpha mat1 mat2 input =
  -- unsafePerformIO: assumes the ATen op returns a fresh tensor and leaves
  -- its arguments untouched.
  unsafePerformIO $
    ATen.cast5 ATen.Managed.addmm_tttss input mat1 mat2 beta alpha
-- | Total number of elements in the tensor (ATen @numel@), converted from
-- ATen's 64-bit count to an 'Int'.
numel ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Int
numel input =
  -- unsafePerformIO: assumes the query has no side effects on @input@.
  unsafePerformIO $ ATen.cast1 ATen.Managed.tensor_numel input
-- | Quantization scale of a tensor (ATen @q_scale@).
--
-- Presumably only meaningful for per-tensor quantized tensors — confirm
-- against the ATen docs.
qScale ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Double
qScale input =
  -- unsafePerformIO: assumes the query has no side effects on @input@.
  unsafePerformIO $ ATen.cast1 ATen.Managed.q_scale_t input
-- | Quantization zero point of a tensor (ATen @q_zero_point@), converted from
-- ATen's 64-bit value to an 'Int'.
--
-- Presumably only meaningful for per-tensor quantized tensors — confirm
-- against the ATen docs.
qZeroPoint ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Int
qZeroPoint input =
  -- unsafePerformIO: assumes the query has no side effects on @input@.
  unsafePerformIO $ ATen.cast1 ATen.Managed.q_zero_point_t input
-- | Whether a recurrent layer runs over the sequence in one direction or in
-- both. Used promoted, at the type level: 'NumberOfDirections' maps the
-- constructors to 2 and 1.
data RNNDirectionality
  = Bidirectional
  | Unidirectional
  deriving (Show, Generic)
-- | Number of directions a recurrent layer runs in: 2 for 'Bidirectional',
-- 1 for 'Unidirectional'.
type family NumberOfDirections (directionality :: RNNDirectionality) :: Nat where
  NumberOfDirections 'Bidirectional = 2
  NumberOfDirections 'Unidirectional = 1
-- | Term-level reflection of a type-level 'RNNDirectionality'.
--
-- The method does not mention the class variable, so it is selected with a
-- type application, e.g. @rnnBidirectional \@'Bidirectional@.
class KnownRNNDirectionality (directionality :: RNNDirectionality) where
  -- | 'True' exactly when @directionality@ is 'Bidirectional'.
  rnnBidirectional :: Bool

instance KnownRNNDirectionality 'Bidirectional where
  rnnBidirectional = True

instance KnownRNNDirectionality 'Unidirectional where
  rnnBidirectional = False
-- | Layout of RNN input/output tensors: batch-major ('BatchFirst') or
-- sequence-major ('SequenceFirst'). Used promoted, at the type level; see
-- 'RNNShape' for the resulting shapes.
data RNNShapeOrder
  = BatchFirst
  | SequenceFirst
  deriving (Show, Generic)
-- | Term-level reflection of a type-level 'RNNShapeOrder'.
--
-- The method does not mention the class variable, so it is selected with a
-- type application, e.g. @rnnBatchFirst \@'BatchFirst@.
class KnownRNNShapeOrder (shapeOrder :: RNNShapeOrder) where
  -- | 'True' exactly when @shapeOrder@ is 'BatchFirst'.
  rnnBatchFirst :: Bool

instance KnownRNNShapeOrder 'BatchFirst where
  rnnBatchFirst = True

instance KnownRNNShapeOrder 'SequenceFirst where
  rnnBatchFirst = False
-- | Shape of a 3-D RNN input/output tensor for the given layout:
--
-- * 'BatchFirst' gives @[batch, seq, feature]@;
-- * 'SequenceFirst' gives @[seq, batch, feature]@.
type family RNNShape (shapeOrder :: RNNShapeOrder) (seqLen :: Nat) (batchSize :: Nat) (featureSize :: Nat) :: [Nat] where
  RNNShape 'BatchFirst s b f = '[b, s, f]
  RNNShape 'SequenceFirst s b f = '[s, b, f]
-- Shapes of one LSTM layer's parameters for a single direction. The order is
-- input-hidden weight, hidden-hidden weight, input-hidden bias and
-- hidden-hidden bias, judging by the WI/WH/BI/BH suffixes.
--
-- The leading @4 * hidden@ presumably stacks the four gates, as in PyTorch.
-- Only 'LSTMWIShape' uses its second argument; the other three ignore it.
type LSTMWIShape hidden input = '[4 * hidden, input]

type LSTMWHShape hidden input = '[4 * hidden, hidden]

type LSTMBIShape hidden input = '[4 * hidden]

type LSTMBHShape hidden input = '[4 * hidden]
-- | Flat, layer-by-layer list of the shapes of all LSTM parameters.
--
-- Each layer contributes the four @LSTM*Shape@ shapes per direction: four
-- when unidirectional, eight when bidirectional. The first layer reads
-- @inputSize@ features. Deeper layers read the previous layer's output,
-- @hiddenSize * NumberOfDirections directionality@. Earlier layers come
-- first in the list.
type family LSTMRImpl (inputSize :: Nat) (hiddenSize :: Nat) (numLayers :: Nat) (directionality :: RNNDirectionality) :: [[Nat]] where
  LSTMRImpl inSz hid 1 'Unidirectional =
    '[ LSTMWIShape hid inSz,
       LSTMWHShape hid inSz,
       LSTMBIShape hid inSz,
       LSTMBHShape hid inSz
     ]
  LSTMRImpl inSz hid layers 'Unidirectional =
    LSTMRImpl inSz hid (layers - 1) 'Unidirectional
      ++ '[ LSTMWIShape hid (hid * NumberOfDirections 'Unidirectional),
            LSTMWHShape hid (hid * NumberOfDirections 'Unidirectional),
            LSTMBIShape hid (hid * NumberOfDirections 'Unidirectional),
            LSTMBHShape hid (hid * NumberOfDirections 'Unidirectional)
          ]
  -- Bidirectional: the four shapes appear twice per layer, once per direction.
  LSTMRImpl inSz hid 1 'Bidirectional =
    '[ LSTMWIShape hid inSz,
       LSTMWHShape hid inSz,
       LSTMBIShape hid inSz,
       LSTMBHShape hid inSz,
       LSTMWIShape hid inSz,
       LSTMWHShape hid inSz,
       LSTMBIShape hid inSz,
       LSTMBHShape hid inSz
     ]
  LSTMRImpl inSz hid layers 'Bidirectional =
    LSTMRImpl inSz hid (layers - 1) 'Bidirectional
      ++ '[ LSTMWIShape hid (hid * NumberOfDirections 'Bidirectional),
            LSTMWHShape hid (hid * NumberOfDirections 'Bidirectional),
            LSTMBIShape hid (hid * NumberOfDirections 'Bidirectional),
            LSTMBHShape hid (hid * NumberOfDirections 'Bidirectional),
            LSTMWIShape hid (hid * NumberOfDirections 'Bidirectional),
            LSTMWHShape hid (hid * NumberOfDirections 'Bidirectional),
            LSTMBIShape hid (hid * NumberOfDirections 'Bidirectional),
            LSTMBHShape hid (hid * NumberOfDirections 'Bidirectional)
          ]
type family LSTMR' (shapes :: [[Nat]]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) :: [a] where
LSTMR' '[] dtype device = '[]
LSTMR' (shape ': shapes) dtype device = Tensor device dtype shape ': LSTMR' shapes dtype device
-- | Tensor types of all LSTM parameters, in the order given by 'LSTMRImpl'.
-- This is the @tensorParameters@ list expected by 'lstm'.
type LSTMR inputSize hiddenSize numLayers directionality dtype device = LSTMR' (LSTMRImpl inputSize hiddenSize numLayers directionality) dtype device
-- | Multi-layer LSTM over a whole sequence, via ATen's @lstm@.
--
-- Arguments, in order:
--
-- * @tensorParameters@: all weights and biases, typed by 'LSTMR'.
-- * @dropoutProb@: dropout probability handed to ATen.
-- * @dropoutOn@: handed to ATen's @train@ flag; enables dropout.
-- * @(cc, hc)@: the two initial state tensors, forwarded to ATen as the
--   list @[cc, hc]@.
-- * @input@: the input sequence, laid out according to @shapeOrder@.
--
-- Returns the three tensors produced by ATen:
--
-- * the output sequence, of shape @outputShape@;
-- * two final state tensors, each of shape @hxShape@.
--
-- Biases are always reported as present, matching 'LSTMR', which includes
-- both bias vectors for every layer and direction.
--
-- NOTE(review): PyTorch's LSTM takes @hx = [h_0, c_0]@ and returns
-- @(output, h_n, c_n)@. The names @cc@/@hc@ suggest the cell state comes
-- first; confirm which tensor callers actually pass first.
lstm ::
  forall
    shapeOrder
    directionality
    numLayers
    seqLen
    batchSize
    inputSize
    outputSize
    hiddenSize
    inputShape
    outputShape
    hxShape
    tensorParameters
    dtype
    device.
  ( KnownNat numLayers,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hxShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
    tensorParameters ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor]
  ) =>
  HList tensorParameters ->
  Double ->
  Bool ->
  (Tensor device dtype hxShape, Tensor device dtype hxShape) ->
  Tensor device dtype inputShape ->
  ( Tensor device dtype outputShape,
    Tensor device dtype hxShape,
    Tensor device dtype hxShape
  )
lstm tensorParameters dropoutProb dropoutOn (cc, hc) input =
  unsafePerformIO $
    ATen.cast9
      ATen.Managed.lstm_tllbldbbb
      input
      hx
      tensorParameters
      hasBiases
      layerCount
      dropoutProb
      dropoutOn
      (rnnBidirectional @directionality)
      (rnnBatchFirst @shapeOrder)
  where
    -- 'LSTMR' always lists both bias vectors, so ATen must expect them.
    hasBiases :: Bool
    hasBiases = True
    -- ATen takes the initial states as a tensor list.
    hx :: [Tensor device dtype hxShape]
    hx = [cc, hc]
    -- Term-level copy of the type-level layer count; named differently from
    -- the type variable to avoid confusion.
    layerCount :: I.Int64
    layerCount = fromIntegral $ natValI @numLayers
-- | A single LSTM time step, via ATen's @lstm_cell@.
--
-- Arguments:
--
-- * @wi@: weight of shape @[4 * hiddenSize, inputSize]@.
-- * @wh@: weight of shape @[4 * hiddenSize, hiddenSize]@.
-- * @bi@, @bh@: biases of shape @[4 * hiddenSize]@.
-- * @(cc, hc)@: the state pair, forwarded to ATen as the list @[cc, hc]@.
-- * @input@: a batch of shape @[batchSize, inputSize]@.
--
-- Returns the two @[batchSize, hiddenSize]@ tensors produced by ATen.
--
-- NOTE(review): ATen's @lstm_cell@ takes @hx = [h, c]@ and returns
-- @(h', c')@. The names @cc@/@hc@ suggest the cell state comes first;
-- confirm the order callers use.
lstmCell ::
  forall inputSize hiddenSize batchSize dtype device.
  Tensor device dtype '[4 * hiddenSize, inputSize] ->
  Tensor device dtype '[4 * hiddenSize, hiddenSize] ->
  Tensor device dtype '[4 * hiddenSize] ->
  Tensor device dtype '[4 * hiddenSize] ->
  ( Tensor device dtype '[batchSize, hiddenSize],
    Tensor device dtype '[batchSize, hiddenSize]
  ) ->
  Tensor device dtype '[batchSize, inputSize] ->
  ( Tensor device dtype '[batchSize, hiddenSize],
    Tensor device dtype '[batchSize, hiddenSize]
  )
lstmCell wi wh bi bh (cc, hc) input =
  unsafePerformIO $
    ATen.cast6 ATen.Managed.lstm_cell_tltttt input hx wi wh bi bh
  where
    -- ATen takes the state pair as a tensor list.
    hx :: [Tensor device dtype '[batchSize, hiddenSize]]
    hx = [cc, hc]
-- | Shape of a GRU input weight (passed as @wi@ to 'gruCell'):
-- @[3 * hiddenSize, inputSize]@.
type GRUWIShape hiddenSize inputSize = '[3 * hiddenSize, inputSize]
-- | Shape of a GRU hidden weight (passed as @wh@ to 'gruCell'):
-- @[3 * hiddenSize, hiddenSize]@.
--
-- @inputSize@ is unused here. It is kept so that all four GRU shape
-- synonyms take the same arguments, as 'GRURImpl' applies them uniformly.
type GRUWHShape hiddenSize inputSize = '[3 * hiddenSize, hiddenSize]
-- | Shape of a GRU input bias: @[3 * hiddenSize]@ (@inputSize@ unused).
type GRUBIShape hiddenSize inputSize = '[3 * hiddenSize]
-- | Shape of a GRU hidden bias: @[3 * hiddenSize]@ (@inputSize@ unused).
type GRUBHShape hiddenSize inputSize = '[3 * hiddenSize]
-- | Shapes of the flattened GRU parameter list for a stack of
-- @numLayers@ layers.
--
-- Every layer contributes, per direction, one group of four shapes:
-- 'GRUWIShape', 'GRUWHShape', 'GRUBIShape' and 'GRUBHShape', in that
-- order. A bidirectional layer therefore contributes two groups.
--
-- The first layer is parameterised by @inputSize@. Deeper layers are
-- parameterised by @hiddenSize * NumberOfDirections directionality@.
--
-- Layers are listed from first to last.
type family GRURImpl (inputSize :: Nat) (hiddenSize :: Nat) (numLayers :: Nat) (directionality :: RNNDirectionality) :: [[Nat]] where
  -- Base case: a single unidirectional layer. Must precede the general
  -- equation below.
  GRURImpl inputSize hiddenSize 1 'Unidirectional =
    '[ GRUWIShape hiddenSize inputSize,
       GRUWHShape hiddenSize inputSize,
       GRUBIShape hiddenSize inputSize,
       GRUBHShape hiddenSize inputSize
     ]
  -- The first (numLayers - 1) layers, then one group reading the previous
  -- layer's output.
  GRURImpl inputSize hiddenSize numLayers 'Unidirectional =
    GRURImpl inputSize hiddenSize (numLayers - 1) 'Unidirectional
      ++ '[ GRUWIShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional),
            GRUWHShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional),
            GRUBIShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional),
            GRUBHShape hiddenSize (hiddenSize * NumberOfDirections 'Unidirectional)
          ]
  -- Base case: a single bidirectional layer (two identical groups). Must
  -- precede the general equation below.
  GRURImpl inputSize hiddenSize 1 'Bidirectional =
    '[ GRUWIShape hiddenSize inputSize,
       GRUWHShape hiddenSize inputSize,
       GRUBIShape hiddenSize inputSize,
       GRUBHShape hiddenSize inputSize,
       GRUWIShape hiddenSize inputSize,
       GRUWHShape hiddenSize inputSize,
       GRUBIShape hiddenSize inputSize,
       GRUBHShape hiddenSize inputSize
     ]
  -- The first (numLayers - 1) layers, then two groups reading the previous
  -- layer's concatenated output.
  GRURImpl inputSize hiddenSize numLayers 'Bidirectional =
    GRURImpl inputSize hiddenSize (numLayers - 1) 'Bidirectional
      ++ '[ GRUWIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUWHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUBIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUBHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUWIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUWHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUBIShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional),
            GRUBHShape hiddenSize (hiddenSize * NumberOfDirections 'Bidirectional)
          ]
-- | Turn a list of shapes into a list of tensor types on @device@ with
-- @dtype@, preserving order.
type family GRUR' (shapes :: [[Nat]]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) :: [a] where
  GRUR' '[] dtype device = '[]
  GRUR' (shape ': shapes) dtype device = Tensor device dtype shape ': GRUR' shapes dtype device
-- | Tensor types of all GRU parameters, in the order given by 'GRURImpl'.
-- This is the @tensorParameters@ list expected by 'gru'.
type GRUR inputSize hiddenSize numLayers directionality dtype device = GRUR' (GRURImpl inputSize hiddenSize numLayers directionality) dtype device
-- | Multi-layer GRU over a whole sequence, via ATen's @gru@.
--
-- Arguments, in order:
--
-- * @tensorParameters@: all weights and biases, typed by 'GRUR'.
-- * @dropoutProb@: dropout probability handed to ATen.
-- * @dropoutOn@: handed to ATen's @train@ flag; enables dropout.
-- * @hc@: the initial hidden state, of shape @hcShape@.
-- * @input@: the input sequence, laid out according to @shapeOrder@.
--
-- Returns the output sequence and the final hidden state.
--
-- Biases are always reported as present, matching 'GRUR'.
gru ::
  forall
    shapeOrder
    directionality
    numLayers
    seqLen
    batchSize
    inputSize
    outputSize
    hiddenSize
    inputShape
    outputShape
    hcShape
    tensorParameters
    dtype
    device.
  ( KnownNat numLayers,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hcShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
    tensorParameters ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor]
  ) =>
  HList tensorParameters ->
  Double ->
  Bool ->
  Tensor device dtype hcShape ->
  Tensor device dtype inputShape ->
  ( Tensor device dtype outputShape,
    Tensor device dtype hcShape
  )
gru tensorParameters dropoutProb dropoutOn hc input =
  unsafePerformIO $
    ATen.cast9
      ATen.Managed.gru_ttlbldbbb
      input
      hc
      tensorParameters
      hasBiases
      layerCount
      dropoutProb
      dropoutOn
      (rnnBidirectional @directionality)
      (rnnBatchFirst @shapeOrder)
  where
    -- 'GRUR' always lists both bias vectors, so ATen must expect them.
    hasBiases :: Bool
    hasBiases = True
    -- Term-level copy of the type-level layer count.
    layerCount :: I.Int64
    layerCount = fromIntegral $ natValI @numLayers
-- | A single GRU time step, via ATen's @gru_cell@.
--
-- Arguments:
--
-- * @wi@: weight of shape @[3 * hiddenSize, inputSize]@.
-- * @wh@: weight of shape @[3 * hiddenSize, hiddenSize]@.
-- * @bi@, @bh@: biases of shape @[3 * hiddenSize]@.
-- * @hx@: the current hidden state, of shape @[batchSize, hiddenSize]@.
-- * @input@: a batch of shape @[batchSize, inputSize]@.
--
-- Returns the next hidden state, of shape @[batchSize, hiddenSize]@.
gruCell ::
  forall inputSize hiddenSize batchSize dtype device.
  Tensor device dtype '[3 * hiddenSize, inputSize] ->
  Tensor device dtype '[3 * hiddenSize, hiddenSize] ->
  Tensor device dtype '[3 * hiddenSize] ->
  Tensor device dtype '[3 * hiddenSize] ->
  Tensor device dtype '[batchSize, hiddenSize] ->
  Tensor device dtype '[batchSize, inputSize] ->
  Tensor device dtype '[batchSize, hiddenSize]
gruCell wi wh bi bh hx input =
  unsafePerformIO $
    ATen.cast6 ATen.Managed.gru_cell_tttttt input hx wi wh bi bh
-- | Identity on shapes of rank 2 (@[rows, cols]@) or rank 3
-- (@[batch, rows, cols]@); any other rank is a type error.
--
-- Intended to be used as the constraint @shape ~ MatrixOrMatrixBatch shape@,
-- which restricts @shape@ to a matrix or a batch of matrices.
type family MatrixOrMatrixBatch (shape :: [Nat]) :: [Nat] where
  MatrixOrMatrixBatch '[rows, cols] = '[rows, cols]
  MatrixOrMatrixBatch '[batch, rows, cols] = '[batch, rows, cols]
  MatrixOrMatrixBatch _ = TypeError (Text "The input must be matrix or a batch of matrices.")
-- | Upper-triangular part of a matrix (or of each matrix in a batch), via
-- ATen's @triu@. Elements below the given diagonal are zeroed.
--
-- The 'Int' argument is the diagonal offset passed to ATen. 'MatrixOrMatrixBatch'
-- restricts the input to rank 2 or 3. The result has the input's shape.
triu ::
  forall shape dtype device.
  (shape ~ MatrixOrMatrixBatch shape) =>
  Int ->
  Tensor device dtype shape ->
  Tensor device dtype shape
triu diagonalOffset input =
  -- The parameter is not called @diagonal@, to avoid shadowing a
  -- module-level name.
  unsafePerformIO $ ATen.cast2 ATen.Managed.triu_tl input diagonalOffset
-- | Lower-triangular part of a matrix (or of each matrix in a batch), via
-- ATen's @tril@. Elements above the given diagonal are zeroed.
--
-- The 'Int' argument is the diagonal offset passed to ATen. 'MatrixOrMatrixBatch'
-- restricts the input to rank 2 or 3. The result has the input's shape.
tril ::
  forall shape dtype device.
  (shape ~ MatrixOrMatrixBatch shape) =>
  Int ->
  Tensor device dtype shape ->
  Tensor device dtype shape
tril diagonalOffset input =
  -- The parameter is not called @diagonal@, to avoid shadowing a
  -- module-level name.
  unsafePerformIO $ ATen.cast2 ATen.Managed.tril_tl input diagonalOffset
-- | Sum of the main diagonal, via ATen's @trace@.
--
-- NOTE(review): ATen's @trace@ accepts a 2-D input and returns a 0-d
-- tensor, so the declared result type @Tensor device dtype shape@ does not
-- describe the runtime shape. It should presumably be @'[]@, as for
-- 'minAll'. The type is left unchanged so existing callers keep compiling.
trace ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
trace input = unsafePerformIO $ ATen.cast1 ATen.Managed.trace_t input
-- | Select the elements of @input@ where @mask@ is true, via ATen's
-- @masked_select@.
--
-- The two shapes must be broadcastable; the @shape'' ~ Broadcast shape shape'@
-- constraint enforces this. The number of selected elements depends on the
-- mask's runtime values, so the result is an 'UnknownShapeTensor'.
maskedSelect ::
  forall shape shape' shape'' dtype device.
  (shape'' ~ Broadcast shape shape') =>
  Tensor device 'D.Bool shape ->
  Tensor device dtype shape' ->
  UnknownShapeTensor device dtype
maskedSelect mask input =
  UnknownShapeTensor $
    unsafePerformIO $
      ATen.cast2 ATen.Managed.masked_select_tt input mask
-- | Indices of the non-zero elements, via ATen's @nonzero@.
--
-- NOTE(review): ATen's @nonzero@ returns an @Int64@ tensor of indices whose
-- shape depends on the data (@[count, rank]@). The declared result type
-- @Tensor device dtype shape@ does not describe that. The type is left
-- unchanged so existing callers keep compiling; confirm before relying on
-- the result's static shape or dtype.
nonzero ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
nonzero input = unsafePerformIO $ ATen.cast1 ATen.Managed.nonzero_t input
-- | Result shape of gathering along @dim@, or 'Nothing' if the index shape
-- @shape'@ is incompatible with the input shape @shape@.
--
-- The two shapes must agree on every dimension except @dim@. At @dim@, the
-- index size @y@ must be at least 1, and it becomes the result's size.
--
-- The patterns are non-linear: a variable that appears twice forces those
-- parts of both shapes to be equal.
type family GatherDimImpl (shape :: [Nat]) (shape' :: [Nat]) (dim :: Nat) :: Maybe [Nat] where
  -- At the target dimension: trailing dims must coincide and the index
  -- size must be >= 1.
  GatherDimImpl (x ': xs) (y ': xs) 0 = If (1 <=? y) (Just (y ': xs)) Nothing
  -- At the target dimension, the equation above did not apply, so the
  -- trailing dims differ. Without this equation, the next one would match
  -- (when the leading dims are equal) and recurse with @0 - 1@. That
  -- reduction gets stuck, so 'GatherDimCheck' could never report its error.
  GatherDimImpl _ _ 0 = Nothing
  -- Before the target dimension: leading dims must coincide; recurse.
  -- 'AppendToMaybe' (defined elsewhere) presumably conses @x@ onto a 'Just'
  -- result.
  GatherDimImpl (x ': xs) (x ': ys) dim = AppendToMaybe x (GatherDimImpl xs ys (dim - 1))
  -- Any other mismatch, or @dim@ past the end of the shape.
  GatherDimImpl _ _ _ = Nothing
-- | Unwrap the result of 'GatherDimImpl'.
--
-- A 'Nothing' result becomes a compile-time error naming the dimension and
-- the offending index shape. A @'Just' shape''@ result yields @shape''@.
type family GatherDimCheck (shape :: [a]) (shape' :: [a]) (dim :: Nat) (result :: Maybe [a]) :: [a] where
  GatherDimCheck shape shape' dim Nothing =
    TypeError
      ( Text "Cannot gather the tensor at dimension "
          :<>: ShowType dim
          :<>: Text " using index of shape "
          :<>: ShowType shape'
      )
  GatherDimCheck _ _ _ (Just shape'') = shape''
-- | Result shape of gathering @shape@ along @dim@ with an index of shape
-- @shape'@. Produces a type error if the two shapes are incompatible.
type GatherDim shape shape' dim = GatherDimCheck shape shape' dim (GatherDimImpl shape shape' dim)
-- | Gather values of @input@ along dimension @dim@ (0-based) at the
-- positions given by @index@, via ATen's @gather@ with @sparse_grad@ off.
--
-- 'GatherDim' requires @index@ to match @input@'s shape except at @dim@,
-- where its size must be at least 1. The result has @index@'s shape.
gatherDim ::
  forall dim shape shape' dtype device.
  (KnownNat dim, shape' ~ GatherDim shape shape' dim) =>
  Tensor device 'D.Int64 shape' ->
  Tensor device dtype shape ->
  Tensor device dtype shape'
gatherDim index input =
  unsafePerformIO $
    ATen.cast4 ATen.Managed.gather_tltb input dimIndex index sparseGrad
  where
    dimIndex :: Int
    dimIndex = natValI @dim
    sparseGrad :: Bool
    sparseGrad = False
-- | Log-gamma of each element, via ATen's @lgamma@.
--
-- Restricted to standard floating-point dtypes. The result has the input's
-- shape and dtype.
lgamma ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
lgamma input = unsafePerformIO $ ATen.cast1 ATen.Managed.lgamma_t input
-- | Digamma (the derivative of log-gamma) of each element, via ATen's
-- @digamma@.
--
-- Restricted to standard floating-point dtypes. The result has the input's
-- shape and dtype.
digamma ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
digamma input = unsafePerformIO $ ATen.cast1 ATen.Managed.digamma_t input
-- | The @n@-th polygamma function of each element, via ATen's @polygamma@.
--
-- The result has the input's shape and dtype. Unlike 'lgamma' and
-- 'digamma', no dtype validation constraint is imposed here.
polygamma ::
  forall shape dtype device.
  Int ->
  Tensor device dtype shape ->
  Tensor device dtype shape
polygamma n input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.polygamma_lt n input
-- | Inverse error function of each element, via ATen's @erfinv@.
--
-- Restricted to standard floating-point dtypes. The result has the input's
-- shape and dtype.
erfinv ::
  forall shape dtype device.
  (StandardFloatingPointDTypeValidation device dtype) =>
  Tensor device dtype shape ->
  Tensor device dtype shape
erfinv input = unsafePerformIO $ ATen.cast1 ATen.Managed.erfinv_t input
-- | Minimum over all elements, via ATen's @min@ with no dimension argument.
-- The result is a 0-d tensor.
minAll ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype '[]
minAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.min_t input
-- | Remove the element at index @i@ (0-based) from a shape.
--
-- Produces a type error if @i@ is not less than the shape's length.
type family DropValue (shape :: [Nat]) (i :: Nat) :: [Nat] where
  DropValue (_ ': rest) 0 = rest
  DropValue (keep ': rest) i = keep ': DropValue rest (i - 1)
  DropValue '[] _ = TypeError (Text "Can not find a element in the list.")
-- | Remove the first entry equal to @i@ from a named shape.
--
-- Produces a type error if no entry equals @i@. The second equation's
-- pattern is non-linear: it fires only when the head equals @i@.
type family DropNamedValue (shape :: Shape) (i :: Size) :: Shape where
  DropNamedValue '[] _ = TypeError (Text "Can not find a element in the list.")
  DropNamedValue (x : xs) x = xs
  DropNamedValue (x : xs) y = x ': DropNamedValue xs y
-- | Minimum along dimension @d@ (0-based), via ATen's @min(self, dim)@.
--
-- Returns the minimum values and their @Int64@ indices. Both have
-- dimension @d@ removed from the shape, as computed by 'DropValue'.
minDim ::
  forall d shape dtype device.
  (KnownNat d) =>
  Tensor device dtype shape ->
  ( Tensor device dtype (DropValue shape d),
    Tensor device 'D.Int64 (DropValue shape d)
  )
minDim input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.min_tl input (natValI @d)
-- | Maximum over all elements, via ATen's @max@ with no dimension argument.
-- The result is a 0-d tensor.
maxAll ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype '[]
maxAll input = unsafePerformIO $ ATen.cast1 ATen.Managed.max_t input
-- | Maximum along dimension @d@ (0-based), via ATen's @max(self, dim)@.
--
-- Returns the maximum values and their @Int64@ indices. Both have
-- dimension @d@ removed from the shape, as computed by 'DropValue'.
maxDim ::
  forall d shape dtype device.
  (KnownNat d) =>
  Tensor device dtype shape ->
  ( Tensor device dtype (DropValue shape d),
    Tensor device 'D.Int64 (DropValue shape d)
  )
maxDim input =
  unsafePerformIO $ ATen.cast2 ATen.Managed.max_tl input (natValI @d)
-- | Holds exactly when @shape@ has an entry at index @dim@ (0-based), that
-- is, when @dim@ is less than the shape's length. Otherwise it is a type
-- error.
type family HasDim (dim :: Nat) (shape :: [Nat]) :: Constraint where
  HasDim 0 (_ ': _) = ()
  HasDim n (_ ': rest) = HasDim (n - 1) rest
  HasDim _ '[] = TypeError (Text "The dimension of the argument is incorrect.")
-- | Sort along dimension @dim@ (0-based), via ATen's @sort@.
--
-- The 'Bool' selects descending order. Returns the sorted values and the
-- @Int64@ indices of each value in the input. Both have the input's shape.
-- 'HasDim' rejects an out-of-range @dim@ at compile time.
sortDim ::
  forall dim shape dtype device.
  ( KnownNat dim,
    HasDim dim shape
  ) =>
  Bool ->
  Tensor device dtype shape ->
  ( Tensor device dtype shape,
    Tensor device D.Int64 shape
  )
sortDim descending input =
  -- Cast the ATen result tuple straight into typed tensors, as 'minDim'
  -- does. There is no need to go through untyped 'D.Tensor's.
  unsafePerformIO $
    ATen.cast3 ATen.Managed.sort_tlb input dimIndex descending
  where
    dimIndex :: Int
    dimIndex = natValI @dim
-- | Sort a named tensor along the dimension named @dim@, via ATen's @sort@.
--
-- The 'Bool' selects descending order. Returns the sorted values and their
-- @Int64@ indices, both with the input's named shape. 'FindDim' resolves
-- the name to a position.
sortNamedDim ::
  forall dim shape dtype device.
  ( KnownNat (FindDim dim shape)
  ) =>
  Bool ->
  NamedTensor device dtype shape ->
  ( NamedTensor device dtype shape,
    NamedTensor device D.Int64 shape
  )
sortNamedDim descending input =
  let (values, indices) = sortDynamic (toDynamic input)
   in (FromTensor (UnsafeMkTensor values), FromTensor (UnsafeMkTensor indices))
  where
    -- Sort on the untyped tensor; the caller re-wraps the results with the
    -- original named shape. The argument is not called @input@, to avoid
    -- shadowing.
    sortDynamic :: D.Tensor -> (D.Tensor, D.Tensor)
    sortDynamic t =
      unsafePerformIO $ ATen.cast3 ATen.Managed.sort_tlb t dimIndex descending
    dimIndex :: Int
    dimIndex = natValI @(FindDim dim shape)
-- | Indices that would sort @input@ along dimension @dim@ (0-based), via
-- ATen's @argsort@.
--
-- The 'Bool' selects descending order. The result has the input's shape.
-- 'HasDim' rejects an out-of-range @dim@ at compile time.
argSortDim ::
  forall dim shape dtype device.
  ( KnownNat dim,
    HasDim dim shape
  ) =>
  Bool ->
  Tensor device dtype shape ->
  Tensor device D.Int64 shape
argSortDim descending input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.argsort_tlb input dimIndex descending
  where
    dimIndex :: Int
    dimIndex = natValI @dim
-- | Indices that would sort a named tensor along the dimension named @dim@,
-- via ATen's @argsort@.
--
-- The 'Bool' selects descending order. The result has the input's named
-- shape. 'FindDim' resolves the name to a position.
argSortNamedDim ::
  forall dim shape dtype device.
  ( KnownNat (FindDim dim shape)
  ) =>
  Bool ->
  NamedTensor device dtype shape ->
  NamedTensor device D.Int64 shape
argSortNamedDim descending input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.argsort_tlb input dimIndex descending
  where
    dimIndex :: Int
    dimIndex = natValI @(FindDim dim shape)
-- | Validate a top-@k@ request and produce its result.
--
-- As used by 'TopK':
--
-- * @satd@ is the size of dimension @dim@, or 'Nothing'.
-- * @result@ is the shape with that dimension replaced by @k@, or
--   'Nothing'.
--
-- Either being 'Nothing' reports 'DimOutOfBound' (defined elsewhere).
-- Otherwise @k@ must not exceed the dimension's size.
type family TopKCheck (k :: Nat) (shape :: [Nat]) (dim :: Nat) (satd :: Maybe Nat) (result :: Maybe a) :: a where
  TopKCheck _ shape dim _ Nothing = DimOutOfBound shape dim
  TopKCheck _ shape dim Nothing _ = DimOutOfBound shape dim
  TopKCheck k shape dim (Just v) (Just result) = If (k <=? v) result (TypeError (Text "k must be less than or equal to the number of elements in the requested dimension."))
-- | Shape of a top-@k@ result along @dim@: @shape@ with dimension @dim@
-- replaced by @k@. Produces a type error if @dim@ is out of range or @k@
-- is too large.
type TopK k shape dim = TopKCheck k shape dim (ExtractDim dim shape) (ReplaceDim dim shape k)
-- | Rejects, at compile time, dtype/device combinations that 'topk' does not
-- support: any 'D.Bool' tensor, and 'D.Half' tensors on the CPU. Every other
-- combination reduces to the empty constraint.
type family TopKDeviceAndDTypeCheck dtype (device :: (D.DeviceType, Nat)) :: Constraint where
  TopKDeviceAndDTypeCheck D.Bool _ = (TypeError (Text "topk is not defined for Bool tensors."))
  TopKDeviceAndDTypeCheck D.Half '(D.CPU, _) = (TypeError (Text "topk is not defined for Half types on CPU."))
  TopKDeviceAndDTypeCheck _ _ = ()
-- | The @k@ extreme entries of a tensor along dimension @dim@ (ATen @topk@).
--
-- Returns the selected values together with their 'D.Int64' indices. Both
-- outputs have the input shape with dimension @dim@ replaced by @k@ (see
-- 'TopK'). A @k@ larger than that dimension is rejected at compile time, as
-- are the dtype/device pairs listed in 'TopKDeviceAndDTypeCheck'.
topk ::
  forall k dim shape' shape dtype device.
  ( KnownNat k,
    KnownNat dim,
    All KnownNat shape,
    TopKDeviceAndDTypeCheck dtype device,
    shape' ~ TopK k shape dim
  ) =>
  -- | @largest@ flag forwarded to ATen @topk@
  Bool ->
  -- | @sorted@ flag forwarded to ATen @topk@
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | (values, indices)
  (Tensor device dtype shape', Tensor device 'D.Int64 shape')
topk largest sorted input =
  unsafePerformIO $
    ATen.cast5 ATen.Managed.topk_tllbb input kCount dimIndex largest sorted
  where
    kCount, dimIndex :: Int
    kCount = natValI @k
    dimIndex = natValI @dim
-- | Wraps ATen's @alias@ operator. Shape, dtype and device are unchanged.
alias ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
alias input = unsafePerformIO (ATen.cast1 ATen.Managed.alias_t input)
-- | L1 loss between @prediction@ and @target@ (ATen @l1_loss@).
--
-- The reduction is chosen by the type-level @reduction@ parameter, and the
-- result shape follows 'ConditionalReduction'.
l1Loss ::
  forall reduction shape dtype device.
  KnownReduction reduction =>
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | loss
  Tensor device dtype (ConditionalReduction shape reduction)
l1Loss prediction target =
  unsafePerformIO (ATen.cast3 ATen.Managed.l1_loss_ttl prediction target reductionCode)
  where
    reductionCode :: Int
    reductionCode = reductionVal @reduction
-- | Negative log-likelihood loss (ATen @nll_loss@ / @nll_loss2d@).
--
-- @prediction@ has shape @n : c : ds@ and @target@ has shape @n : ds@.
-- @weight@ holds one entry per class. @ignoreIndex@ and the type-level
-- @reduction@ code are forwarded unchanged to ATen.
--
-- The kernel is chosen by the number of trailing dimensions @ds@:
--
-- * none: @nll_loss@ is called directly;
-- * exactly two: @nll_loss2d@ is called directly;
-- * any other count: the trailing dimensions are flattened to
--   @[1, product ds]@ so that @nll_loss2d@ can be used. When the reduction
--   code is @0@, the per-element result is reshaped back to @n : ds@.
nllLoss ::
  forall reduction n c ds dtype device.
  (KnownReduction reduction, KnownNat n, KnownNat c, KnownShape ds) =>
  Tensor device dtype '[c] ->
  Int ->
  Tensor device dtype (n ': c ': ds) ->
  Tensor device 'D.Int64 (n ': ds) ->
  Tensor device dtype (ConditionalReduction (n ': ds) reduction)
nllLoss weight ignoreIndex prediction target =
  case shapeVal @ds of
    [] ->
      unsafePerformIO $
        ATen.cast5 ATen.Managed.nll_loss_tttll prediction target weight reductionCode ignoreIndex
    [_, _] ->
      unsafePerformIO $
        ATen.cast5 ATen.Managed.nll_loss2d_tttll prediction target weight reductionCode ignoreIndex
    h : t -> viaFlattened2d h t
  where
    reductionCode :: Int
    reductionCode = reductionVal @reduction

    batch :: Int
    batch = natValI @n

    -- Collapse the trailing dims @h : t@ into a 1 x (h * product t) plane,
    -- run nll_loss2d on it, and restore @n : h : t@ for the unreduced case.
    viaFlattened2d :: Int -> [Int] -> Tensor device dtype (ConditionalReduction (n ': ds) reduction)
    viaFlattened2d h t =
      let plane = [1, foldr (*) h t]
          flatPrediction = D.reshape (batch : natValI @c : plane) (toDynamic prediction)
          flatTarget = D.reshape (batch : plane) (toDynamic target)
          raw :: D.Tensor
          raw =
            unsafePerformIO $
              ATen.cast5 ATen.Managed.nll_loss2d_tttll flatPrediction flatTarget weight reductionCode ignoreIndex
       in UnsafeMkTensor $
            if reductionCode == 0
              then D.reshape (batch : h : t) raw
              else raw
-- | Smooth L1 loss between @prediction@ and @target@ (ATen @smooth_l1_loss@).
--
-- The reduction is chosen by the type-level @reduction@ parameter, and the
-- result shape follows 'ConditionalReduction'.
smoothL1Loss ::
  forall reduction shape dtype device.
  KnownReduction reduction =>
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | loss
  Tensor device dtype (ConditionalReduction shape reduction)
smoothL1Loss prediction target =
  unsafePerformIO (ATen.cast3 ATen.Managed.smooth_l1_loss_ttl prediction target reductionCode)
  where
    reductionCode :: Int
    reductionCode = reductionVal @reduction
-- | Soft-margin loss between @prediction@ and @target@ (ATen
-- @soft_margin_loss@).
--
-- The reduction is chosen by the type-level @reduction@ parameter, and the
-- result shape follows 'ConditionalReduction'.
softMarginLoss ::
  forall reduction shape dtype device.
  KnownReduction reduction =>
  -- | prediction
  Tensor device dtype shape ->
  -- | target
  Tensor device dtype shape ->
  -- | loss
  Tensor device dtype (ConditionalReduction shape reduction)
softMarginLoss prediction target =
  unsafePerformIO $
    ATen.cast3
      ATen.Managed.soft_margin_loss_ttl
      prediction
      target
      (reductionVal @reduction)
-- | ATen @elu@ applied elementwise, with the given @alpha@, @scale@ and
-- @input_scale@ scalars. Restricted to standard floating-point dtypes.
elu ::
  forall shape dtype a device.
  (D.Scalar a, StandardFloatingPointDTypeValidation device dtype) =>
  -- | alpha
  a ->
  -- | scale
  a ->
  -- | input scale
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
elu alpha scale inputScale input =
  unsafePerformIO (ATen.cast4 ATen.Managed.elu_tsss input alpha scale inputScale)
-- | ATen @hardtanh@ applied elementwise with the given lower and upper bounds.
hardTanh ::
  forall shape dtype device.
  -- | min_val
  Float ->
  -- | max_val
  Float ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
hardTanh lower upper input =
  unsafePerformIO (ATen.cast3 ATen.Managed.hardtanh_tss input lower upper)
-- | ATen @leaky_relu@ applied elementwise with the given negative slope.
-- Restricted to standard floating-point dtypes.
leakyRelu ::
  forall a shape dtype device.
  (D.Scalar a, StandardFloatingPointDTypeValidation device dtype) =>
  -- | negative slope
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
leakyRelu slope input =
  unsafePerformIO (ATen.cast2 ATen.Managed.leaky_relu_ts input slope)
-- | ATen @log_sigmoid@ applied elementwise. Restricted to standard
-- floating-point dtypes.
logSigmoid ::
  forall shape dtype device.
  StandardFloatingPointDTypeValidation device dtype =>
  Tensor device dtype shape ->
  Tensor device dtype shape
logSigmoid input =
  unsafePerformIO (ATen.cast1 ATen.Managed.log_sigmoid_t input)
-- | ATen @softplus@ applied elementwise with the given @beta@ and
-- @threshold@ scalars.
softplus ::
  forall a shape dtype device.
  D.Scalar a =>
  -- | beta
  a ->
  -- | threshold
  a ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
softplus beta threshold input =
  unsafePerformIO (ATen.cast3 ATen.Managed.softplus_tss input beta threshold)
-- | ATen @softshrink@ applied elementwise with the given @lambda@.
softShrink ::
  forall shape dtype device.
  -- | lambda
  Float ->
  -- | input
  Tensor device dtype shape ->
  -- | output
  Tensor device dtype shape
softShrink lambd input =
  unsafePerformIO (ATen.cast2 ATen.Managed.softshrink_ts input lambd)
-- | 2-D adaptive average pooling (ATen @adaptive_avg_pool2d@).
--
-- The spatial output size is the type-level pair @outputSize@. Batch and
-- channel dimensions are preserved.
adaptiveAvgPool2d ::
  forall outputSize channelSize inputSize0 inputSize1 batchSize dtype device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         batchSize,
         Torch.Typed.Auxiliary.Fst outputSize,
         Torch.Typed.Auxiliary.Snd outputSize
       ]
  ) =>
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  Tensor device dtype '[batchSize, channelSize, Torch.Typed.Auxiliary.Fst outputSize, Torch.Typed.Auxiliary.Snd outputSize]
adaptiveAvgPool2d input =
  unsafePerformIO (ATen.cast2 ATen.Managed.adaptive_avg_pool2d_tl input targetSize)
  where
    targetSize :: [Int]
    targetSize =
      [ natValI @(Torch.Typed.Auxiliary.Fst outputSize),
        natValI @(Torch.Typed.Auxiliary.Snd outputSize)
      ]
-- | 2-D adaptive average pooling to the type-level size @outputSize@.
--
-- NOTE(review): despite its name, this calls the same ATen kernel as
-- 'adaptiveAvgPool2d' (@adaptive_avg_pool2d@), not an MKL-DNN-specific one.
-- This is presumably because typed tensors do not track an MKL-DNN layout;
-- confirm before relying on MKL-DNN behaviour.
mkldnnAdaptiveAvgPool2d ::
  forall outputSize channelSize inputSize0 inputSize1 batchSize dtype device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         batchSize,
         Torch.Typed.Auxiliary.Fst outputSize,
         Torch.Typed.Auxiliary.Snd outputSize
       ]
  ) =>
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  Tensor device dtype '[batchSize, channelSize, Torch.Typed.Auxiliary.Fst outputSize, Torch.Typed.Auxiliary.Snd outputSize]
mkldnnAdaptiveAvgPool2d input =
  unsafePerformIO (ATen.cast2 ATen.Managed.adaptive_avg_pool2d_tl input targetSize)
  where
    targetSize :: [Int]
    targetSize =
      [ natValI @(Torch.Typed.Auxiliary.Fst outputSize),
        natValI @(Torch.Typed.Auxiliary.Snd outputSize)
      ]
-- | 3-D adaptive average pooling (ATen @adaptive_avg_pool3d@).
--
-- The spatial output size is the type-level triple @outputSize@. Batch and
-- channel dimensions are preserved.
adaptiveAvgPool3d ::
  forall
    outputSize
    channelSize
    inputSize0
    inputSize1
    inputSize2
    batchSize
    dtype
    device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize,
         Fst3 outputSize,
         Snd3 outputSize,
         Trd3 outputSize
       ]
  ) =>
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1, inputSize2] ->
  Tensor device dtype '[batchSize, channelSize, Fst3 outputSize, Snd3 outputSize, Trd3 outputSize]
adaptiveAvgPool3d input =
  unsafePerformIO (ATen.cast2 ATen.Managed.adaptive_avg_pool3d_tl input targetSize)
  where
    targetSize :: [Int]
    targetSize =
      [ natValI @(Fst3 outputSize),
        natValI @(Snd3 outputSize),
        natValI @(Trd3 outputSize)
      ]
-- | 2-D adaptive max pooling (ATen @adaptive_max_pool2d@).
--
-- Returns the pooled values and their 'D.Int64' indices, both with spatial
-- size given by the type-level pair @outputSize@.
adaptiveMaxPool2d ::
  forall outputSize channelSize inputSize0 inputSize1 batchSize dtype device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         batchSize,
         Torch.Typed.Auxiliary.Fst outputSize,
         Torch.Typed.Auxiliary.Snd outputSize
       ]
  ) =>
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  ( Tensor device dtype '[batchSize, channelSize, Torch.Typed.Auxiliary.Fst outputSize, Torch.Typed.Auxiliary.Snd outputSize],
    Tensor device 'D.Int64 '[batchSize, channelSize, Torch.Typed.Auxiliary.Fst outputSize, Torch.Typed.Auxiliary.Snd outputSize]
  )
adaptiveMaxPool2d input =
  unsafePerformIO (ATen.cast2 ATen.Managed.adaptive_max_pool2d_tl input targetSize)
  where
    targetSize :: [Int]
    targetSize =
      [ natValI @(Torch.Typed.Auxiliary.Fst outputSize),
        natValI @(Torch.Typed.Auxiliary.Snd outputSize)
      ]
-- | 3-D adaptive max pooling (ATen @adaptive_max_pool3d@).
--
-- Returns the pooled values and their 'D.Int64' indices, both with spatial
-- size given by the type-level triple @outputSize@.
adaptiveMaxPool3d ::
  forall
    outputSize
    channelSize
    inputSize0
    inputSize1
    inputSize2
    batchSize
    dtype
    device.
  ( All
      KnownNat
      '[ channelSize,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize,
         Fst3 outputSize,
         Snd3 outputSize,
         Trd3 outputSize
       ]
  ) =>
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1, inputSize2] ->
  ( Tensor device dtype '[batchSize, channelSize, Fst3 outputSize, Snd3 outputSize, Trd3 outputSize],
    Tensor device 'D.Int64 '[batchSize, channelSize, Fst3 outputSize, Snd3 outputSize, Trd3 outputSize]
  )
adaptiveMaxPool3d input =
  unsafePerformIO (ATen.cast2 ATen.Managed.adaptive_max_pool3d_tl input targetSize)
  where
    targetSize :: [Int]
    targetSize =
      [ natValI @(Fst3 outputSize),
        natValI @(Snd3 outputSize),
        natValI @(Trd3 outputSize)
      ]
-- | 2-D average pooling (ATen @avg_pool2d@).
--
-- Kernel size, stride and padding are type-level pairs. The output spatial
-- sizes are checked with 'ConvSideCheck'. The call uses @ceil_mode = False@
-- and @count_include_pad = True@.
--
-- The FFI binding takes @divisor_override@ as a plain integer, so "no
-- override" cannot be expressed. The previous value, @1@, divided every
-- window sum by 1, which turned this into sum pooling. Passing the kernel
-- area instead reproduces PyTorch's default divisor for these settings:
-- with @ceil_mode = False@ every window lies inside the padded input, so its
-- size is always @kernelH * kernelW@.
avgPool2d ::
  forall
    kernelSize
    stride
    padding
    channelSize
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device.
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst kernelSize,
         Torch.Typed.Auxiliary.Snd kernelSize,
         Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         channelSize,
         inputSize0,
         inputSize1,
         batchSize
       ],
    ConvSideCheck inputSize0 (Torch.Typed.Auxiliary.Fst kernelSize) (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 (Torch.Typed.Auxiliary.Snd kernelSize) (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1] ->
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1]
avgPool2d input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.avg_pool2d_tlllbbl
      input
      ([kernelH, kernelW] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst stride), natValI @(Torch.Typed.Auxiliary.Snd stride)] :: [Int])
      ([natValI @(Torch.Typed.Auxiliary.Fst padding), natValI @(Torch.Typed.Auxiliary.Snd padding)] :: [Int])
      False -- ceil_mode
      True -- count_include_pad
      (kernelH * kernelW :: Int) -- divisor_override: window area (= ATen default here)
  where
    kernelH, kernelW :: Int
    kernelH = natValI @(Torch.Typed.Auxiliary.Fst kernelSize)
    kernelW = natValI @(Torch.Typed.Auxiliary.Snd kernelSize)
-- | 3-D average pooling (ATen @avg_pool3d@).
--
-- Kernel size, stride and padding are type-level triples. The output
-- spatial sizes are checked with 'ConvSideCheck'. The call uses
-- @ceil_mode = False@ and @count_include_pad = True@.
--
-- The FFI binding takes @divisor_override@ as a plain integer, so "no
-- override" cannot be expressed. The previous value, @1@, divided every
-- window sum by 1, which turned this into sum pooling. Passing the kernel
-- volume instead reproduces PyTorch's default divisor for these settings:
-- with @ceil_mode = False@ every window lies inside the padded input.
avgPool3d ::
  forall
    kernelSize
    stride
    padding
    channelSize
    inputSize0
    inputSize1
    inputSize2
    batchSize
    outputSize0
    outputSize1
    outputSize2
    dtype
    device.
  ( All
      KnownNat
      '[ Fst3 kernelSize,
         Snd3 kernelSize,
         Trd3 kernelSize,
         Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         channelSize,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 (Fst3 kernelSize) (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 (Snd3 kernelSize) (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 (Trd3 kernelSize) (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  Tensor device dtype '[batchSize, channelSize, inputSize0, inputSize1, inputSize2] ->
  Tensor device dtype '[batchSize, channelSize, outputSize0, outputSize1, outputSize2]
avgPool3d input =
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.avg_pool3d_tlllbbl
      input
      ([kernelD, kernelH, kernelW] :: [Int])
      ([natValI @(Fst3 stride), natValI @(Snd3 stride), natValI @(Trd3 stride)] :: [Int])
      ([natValI @(Fst3 padding), natValI @(Snd3 padding), natValI @(Trd3 padding)] :: [Int])
      False -- ceil_mode
      True -- count_include_pad
      (kernelD * kernelH * kernelW :: Int) -- divisor_override: window volume (= ATen default here)
  where
    kernelD, kernelH, kernelW :: Int
    kernelD = natValI @(Fst3 kernelSize)
    kernelH = natValI @(Snd3 kernelSize)
    kernelW = natValI @(Trd3 kernelSize)
-- | Type-level validation and result shape for the 2-D upsampling functions.
--
-- Given a 4-dimensional shape @[b, c, w, h]@ and targets @h'@ and @w'@, the
-- result is @[b, c, w', h']@, provided neither target is smaller than the
-- current size (equal sizes are accepted). Shapes of any other rank are
-- rejected.
--
-- NOTE(review): here dimension 2 is called "width" and dimension 3
-- "height". The upsampling functions pass @[w, h]@ as ATen's @output_size@,
-- so the two sides agree with each other.
type family Upsample2dCheck shape h w where
  Upsample2dCheck (b : c : w : h : '[]) h' w' =
    If
      (h <=? h')
      ( If
          (w <=? w')
          (b : c : w' : h' : '[])
          (TypeError (Text "Target width must be greater than or equal to current width!"))
      )
      (TypeError (Text "Target height must be greater than or equal to current height!"))
  Upsample2dCheck _ _ _ = TypeError (Text "Shape must be 4 dimensional!")

-- | Output shape of the 2-D upsampling functions; see 'Upsample2dCheck'.
type Upsample2d shape h w = Upsample2dCheck shape h w
-- | Bilinear 2-D upsampling (ATen @upsample_bilinear2d@) to the type-level
-- target sizes @w@ and @h@, passed to ATen as @output_size = [w, h]@.
-- The result shape is given by 'Upsample2d'.
upsample_bilinear2d ::
  forall w h shape dtype device.
  (KnownNat h, KnownNat w, All KnownNat shape) =>
  -- | @align_corners@ flag forwarded to ATen
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | upsampled output
  Tensor device dtype (Upsample2d shape h w)
upsample_bilinear2d alignCorners input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.upsample_bilinear2d_tlb input targetSize alignCorners
  where
    targetSize :: [Int]
    targetSize = [natValI @w, natValI @h]
-- | Bicubic 2-D upsampling (ATen @upsample_bicubic2d@) to the type-level
-- target sizes @w@ and @h@, passed to ATen as @output_size = [w, h]@.
-- The result shape is given by 'Upsample2d'.
upsample_bicubic2d ::
  forall w h shape dtype device.
  (KnownNat h, KnownNat w, All KnownNat shape) =>
  -- | @align_corners@ flag forwarded to ATen
  Bool ->
  -- | input
  Tensor device dtype shape ->
  -- | upsampled output
  Tensor device dtype (Upsample2d shape h w)
upsample_bicubic2d alignCorners input =
  unsafePerformIO $
    ATen.cast3 ATen.Managed.upsample_bicubic2d_tlb input targetSize alignCorners
  where
    targetSize :: [Int]
    targetSize = [natValI @w, natValI @h]
-- | Nearest-neighbour 2-D upsampling (ATen @upsample_nearest2d@) to the
-- type-level target sizes @w@ and @h@, passed to ATen as
-- @output_size = [w, h]@. The result shape is given by 'Upsample2d'.
upsample_nearest2d ::
  forall w h shape dtype device.
  (KnownNat h, KnownNat w, All KnownNat shape) =>
  -- | input
  Tensor device dtype shape ->
  -- | upsampled output
  Tensor device dtype (Upsample2d shape h w)
upsample_nearest2d input =
  unsafePerformIO (ATen.cast2 ATen.Managed.upsample_nearest2d_tl input targetSize)
  where
    targetSize :: [Int]
    targetSize = [natValI @w, natValI @h]