{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Torch.GraduallyTyped.Tensor.Creation
  ( sOnes,
    ones,
    sZeros,
    zeros,
    sFull,
    full,
    sRandn,
    randn,
    sArangeNaturals,
    arangeNaturals,
    sEye,
    eye,
    sEyeSquare,
    eyeSquare,
  )
where

import Control.Monad.Catch (MonadThrow)
import Data.Monoid (All (..))
import Data.Singletons (SingI (..), SingKind (fromSing))
import Torch.GraduallyTyped.DType (SDataType)
import Torch.GraduallyTyped.Device (SDevice)
import Torch.GraduallyTyped.Internal.TensorOptions (tensorDims, tensorOptions)
import Torch.GraduallyTyped.Layout (SLayout (..))
import Torch.GraduallyTyped.Prelude (forgetIsChecked)
import Torch.GraduallyTyped.Random (Generator (..), SGetGeneratorDevice, sCreateWithGenerator)
import Torch.GraduallyTyped.RequiresGradient (SGradient)
import Torch.GraduallyTyped.Scalar (Scalar)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SSize, Shape (..), dimName, dimSize)
import Torch.GraduallyTyped.Tensor.Type (Tensor (..), TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.Internal.Cast (cast2, cast3, cast4)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.TensorFactories as ATen

-- $setup
-- >>> import Data.Int (Int16)
-- >>> import Torch.GraduallyTyped.Prelude.List (SList (..))
-- >>> import Torch.GraduallyTyped

-- | Create a gradually typed tensor of ones.
--
-- >>> shape = SShape $ SName @"batch" :&: SSize @32 :|: SUncheckedName "feature" :&: SUncheckedSize 8 :|: SNil
-- >>> :type sOnes $ TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SInt64) shape
-- sOnes $ TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SInt64) shape
--   :: MonadThrow m =>
--      m (Tensor
--           ('Gradient 'WithoutGradient)
--           ('Layout 'Dense)
--           ('Device 'CPU)
--           ('DataType 'Int64)
--           ('Shape
--              '[ 'Dim ('Name "batch") ('Size 32),
--                 'Dim 'UncheckedName 'UncheckedSize]))
sOnes ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  -- | specification (gradient, layout, device, data type, and shape) of the tensor to create
  TensorSpec gradient layout device dataType shape ->
  -- | tensor of ones matching the specification
  m (Tensor gradient layout device dataType shape)
sOnes TensorSpec {..} = unsafeThrowableIO $ do
  let opts = tensorOptions tsGradient tsLayout tsDevice tsDataType
      dims = tensorDims tsShape
  tPtr <- case (map dimName dims, map dimSize dims) of
    (names, sizes)
      -- Every dimension name is "*" (vacuously true for an empty shape):
      -- call the ATen factory that takes sizes only.
      | getAll . foldMap (All . (== "*")) $ names -> cast2 ATen.ones_lo sizes opts
      -- At least one dimension has another name: forward the names as well.
      | otherwise -> cast3 ATen.ones_lNo sizes names opts
  -- The raw pointer carries no type information; the phantom parameters
  -- of the result are fixed by the signature.
  pure . UnsafeTensor $ tPtr

-- | Create a typed tensor of ones.
--
-- >>> ones :: IO (CPUParameter ('DataType 'Float) ('Shape '[]))
-- Tensor Float []  1.0000
-- >>> ones :: IO (CPUTensor ('DataType 'Int64) ('Shape '[ 'Dim ('Name "*") ('Size 1)]))
-- Tensor Int64 [1] [ 1]
ones ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  (SingI gradient, SingI layout, SingI device, SingI dataType, SingI shape) =>
  -- | tensor of ones whose specification is taken from the type-level parameters
  m (Tensor gradient layout device dataType shape)
-- The singletons of all five type parameters are summoned through 'SingI',
-- packed into a 'TensorSpec', and handed to 'sOnes'.
ones = sOnes $ TensorSpec (sing @gradient) (sing @layout) (sing @device) (sing @dataType) (sing @shape)

-- | Create a gradually typed tensor of zeros.
--
-- >>> shape = SShape $ SName @"batch" :&: SSize @32 :|: SUncheckedName "feature" :&: SUncheckedSize 8 :|: SNil
-- >>> :type sZeros $ TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SInt64) shape
-- sZeros $ TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SInt64) shape
--   :: MonadThrow m =>
--      m (Tensor
--           ('Gradient 'WithoutGradient)
--           ('Layout 'Dense)
--           ('Device 'CPU)
--           ('DataType 'Int64)
--           ('Shape
--              '[ 'Dim ('Name "batch") ('Size 32),
--                 'Dim 'UncheckedName 'UncheckedSize]))
sZeros ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  -- | specification (gradient, layout, device, data type, and shape) of the tensor to create
  TensorSpec gradient layout device dataType shape ->
  -- | tensor of zeros matching the specification
  m (Tensor gradient layout device dataType shape)
sZeros TensorSpec {..} = unsafeThrowableIO $ do
  let opts = tensorOptions tsGradient tsLayout tsDevice tsDataType
      dims = tensorDims tsShape
  tPtr <- case (map dimName dims, map dimSize dims) of
    (names, sizes)
      -- Every dimension name is "*" (vacuously true for an empty shape):
      -- call the ATen factory that takes sizes only.
      | getAll . foldMap (All . (== "*")) $ names -> cast2 ATen.zeros_lo sizes opts
      -- At least one dimension has another name: forward the names as well.
      | otherwise -> cast3 ATen.zeros_lNo sizes names opts
  -- The raw pointer carries no type information; the phantom parameters
  -- of the result are fixed by the signature.
  pure . UnsafeTensor $ tPtr

-- | Create a typed tensor of zeros.
--
-- >>> zeros :: IO (CPUParameter ('DataType 'Float) ('Shape '[]))
-- Tensor Float []  0.0000
-- >>> zeros :: IO (CPUTensor ('DataType 'Int64) ('Shape '[ 'Dim ('Name "*") ('Size 1)]))
-- Tensor Int64 [1] [ 0]
zeros ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  (SingI gradient, SingI layout, SingI device, SingI dataType, SingI shape) =>
  -- | tensor of zeros whose specification is taken from the type-level parameters
  m (Tensor gradient layout device dataType shape)
-- The singletons of all five type parameters are summoned through 'SingI',
-- packed into a 'TensorSpec', and handed to 'sZeros'.
zeros = sZeros $ TensorSpec (sing @gradient) (sing @layout) (sing @device) (sing @dataType) (sing @shape)

-- | Create a gradually typed tensor filled with a given scalar value.
--
-- >>> shape = SShape $ SName @"batch" :&: SSize @32 :|: SUncheckedName "feature" :&: SUncheckedSize 8 :|: SNil
-- >>> input = -1
-- >>> :type sFull (TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SInt64) shape) input
-- sFull (TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SInt64) shape) input
--   :: MonadThrow m =>
--      m (Tensor
--           ('Gradient 'WithoutGradient)
--           ('Layout 'Dense)
--           ('Device 'CPU)
--           ('DataType 'Int64)
--           ('Shape
--              '[ 'Dim ('Name "batch") ('Size 32),
--                 'Dim 'UncheckedName 'UncheckedSize]))
sFull ::
  forall gradient layout device dataType shape input m.
  (MonadThrow m, Scalar input) =>
  -- | specification (gradient, layout, device, data type, and shape) of the tensor to create
  TensorSpec gradient layout device dataType shape ->
  -- | scalar value passed to the ATen @full@ factory as the fill value
  input ->
  -- | tensor matching the specification
  m (Tensor gradient layout device dataType shape)
sFull TensorSpec {..} input = unsafeThrowableIO $ do
  let opts = tensorOptions tsGradient tsLayout tsDevice tsDataType
      dims = tensorDims tsShape
  tPtr <- case (dimName <$> dims, dimSize <$> dims) of
    (names, sizes)
      -- Every dimension name is "*" (vacuously true for an empty shape):
      -- call the ATen factory that takes sizes only.
      | getAll . foldMap (\name -> All $ name == "*") $ names -> cast3 ATen.full_lso sizes input opts
      -- At least one dimension has another name: forward the names as well.
      | otherwise -> cast4 ATen.full_lsNo sizes input names opts
  -- The raw pointer carries no type information; the phantom parameters
  -- of the result are fixed by the signature.
  pure . UnsafeTensor $ tPtr

-- | Create a typed tensor filled with a given scalar value.
--
-- >>> full (-1) :: IO (CPUParameter ('DataType 'Float) ('Shape '[]))
-- Tensor Float [] -1.0000
-- >>> full (-1) :: IO (CPUTensor ('DataType 'Int64) ('Shape '[ 'Dim ('Name "*") ('Size 1)]))
-- Tensor Int64 [1] [-1]
full ::
  forall gradient layout device dataType shape input m.
  (MonadThrow m, SingI gradient, SingI layout, SingI device, SingI dataType, SingI shape, Scalar input) =>
  -- | scalar value forwarded to 'sFull'
  input ->
  -- | tensor whose specification is taken from the type-level parameters
  m (Tensor gradient layout device dataType shape)
-- Partially applies 'sFull' to the 'TensorSpec' assembled from the summoned
-- singletons; the remaining argument is the fill value.
full = sFull $ TensorSpec (sing @gradient) (sing @layout) (sing @device) (sing @dataType) (sing @shape)

-- | Create a gradually typed random tensor.
sRandn ::
  forall gradient layout device dataType shape generatorDevice m.
  (SGetGeneratorDevice generatorDevice, MonadThrow m) =>
  -- | specification (gradient, layout, device, data type, and shape) of the tensor to create
  TensorSpec gradient layout device dataType shape ->
  -- | random generator handed to 'sCreateWithGenerator'
  Generator generatorDevice ->
  -- | the random tensor paired with the generator returned by 'sCreateWithGenerator';
  -- both live on the unification of the requested device and the generator device
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sRandn tSpec generator =
  -- 'sCreateWithGenerator' supplies the tensor options, the dimensions, and
  -- the raw generator pointer; this callback only selects the ATen factory.
  sCreateWithGenerator tSpec generator $
    \opts dims genPtr ->
      case (map dimName dims, map dimSize dims) of
        (names, sizes)
          -- Every dimension name is "*" (vacuously true for an empty shape):
          -- call the ATen factory that takes sizes only.
          | getAll . foldMap (\name -> All $ name == "*") $ names -> cast3 ATen.randn_lGo sizes genPtr opts
          -- At least one dimension has another name: forward the names as well.
          | otherwise -> cast4 ATen.randn_lGNo sizes genPtr names opts

-- | Create a typed random tensor.
randn ::
  forall gradient layout device dataType shape generatorDevice m.
  (SGetGeneratorDevice generatorDevice, MonadThrow m) =>
  (SingI gradient, SingI layout, SingI device, SingI dataType, SingI shape) =>
  -- | random generator forwarded to 'sRandn'
  Generator generatorDevice ->
  -- | the random tensor paired with the generator returned by 'sRandn'
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
-- Partially applies 'sRandn' to the 'TensorSpec' assembled from the summoned
-- singletons; the remaining argument is the generator.
randn = sRandn $ TensorSpec (sing @gradient) (sing @layout) (sing @device) (sing @dataType) (sing @shape)

-- | Create a gradually typed one-dimensional tensor of the numbers @0@ to @size - 1@.
sArangeNaturals ::
  forall m gradient layout device dataType size shape.
  (MonadThrow m, shape ~ 'Shape '[ 'Dim ('Name "*") size]) =>
  -- | whether the tensor requires a gradient
  SGradient gradient ->
  -- | memory layout of the tensor
  SLayout layout ->
  -- | device of the tensor
  SDevice device ->
  -- | data type of the tensor
  SDataType dataType ->
  -- | size of the single dimension of the result
  SSize size ->
  -- | one-dimensional tensor whose only dimension is named @"*"@ and has the given size
  m (Tensor gradient layout device dataType shape)
sArangeNaturals gradient layout device dataType size = unsafeThrowableIO $ do
  let opts = tensorOptions gradient layout device dataType
      -- Demote the size singleton to a plain value and drop the
      -- checked/unchecked wrapper so it can be passed to ATen as a scalar.
      size' = forgetIsChecked . fromSing $ size
  tPtr <- cast2 ATen.arange_so size' opts
  -- The raw pointer carries no type information; the phantom parameters
  -- of the result are fixed by the signature.
  pure . UnsafeTensor $ tPtr

-- | Create a typed one-dimensional tensor of the numbers @0@ to @size - 1@.
arangeNaturals ::
  forall size gradient layout device dataType shape m.
  ( MonadThrow m,
    shape ~ 'Shape '[ 'Dim ('Name "*") size],
    SingI gradient,
    SingI layout,
    SingI device,
    SingI dataType,
    SingI size
  ) =>
  -- | one-dimensional tensor whose specification is taken from the type-level parameters
  m (Tensor gradient layout device dataType shape)
-- All five singleton arguments of 'sArangeNaturals' are summoned through
-- 'SingI'; @size@ comes first in the @forall@ so it is the first type
-- application at call sites.
arangeNaturals = sArangeNaturals (sing @gradient) (sing @layout) (sing @device) (sing @dataType) (sing @size)

-- | Create a gradually typed rectangular tensor with ones on the diagonal and zeros elsewhere.
--
-- The tensor options are assembled from the gradient, layout, device and
-- data type singletons. The row and column counts are demoted from their
-- singletons and handed to @ATen.eye_llo@ together with those options.
--
-- The returned pointer is wrapped with 'UnsafeTensor', so the shape in the
-- result type is asserted by the signature rather than checked at runtime.
-- The whole action is run through 'unsafeThrowableIO', which yields the
-- result in any 'MonadThrow' monad.
-- NOTE(review): presumably backend failures surface via 'MonadThrow' --
-- confirm against "Torch.Internal.GC".
sEye ::
  forall gradient layout device dataType rows cols shape m.
  (MonadThrow m, shape ~ 'Shape '[ 'Dim ('Name "*") rows, 'Dim ('Name "*") cols]) =>
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  SDataType dataType ->
  SSize rows ->
  SSize cols ->
  m (Tensor gradient layout device dataType shape)
sEye gradient layout device dataType rows cols = unsafeThrowableIO $ do
  let opts = tensorOptions gradient layout device dataType
      -- The demoted sizes are 'Integer's; they are narrowed to 'Int' before
      -- being passed to the binding. 'fromInteger' wraps silently for values
      -- outside the 'Int' range.
      rows' :: Int = fromInteger . forgetIsChecked . fromSing $ rows
      cols' :: Int = fromInteger . forgetIsChecked . fromSing $ cols
  UnsafeTensor <$> cast3 ATen.eye_llo rows' cols' opts

-- | Create a typed rectangular tensor with ones on the diagonal and zeros elsewhere.
--
-- This is the implicit-argument counterpart of 'sEye': every singleton that
-- 'sEye' takes as a value is obtained here from the corresponding type-level
-- argument via 'sing', and the call is otherwise forwarded unchanged.
--
-- The type variables are ordered with @rows@ and @cols@ first, so the two
-- dimensions can be chosen with the first two visible type applications.
eye ::
  forall rows cols gradient layout device dataType shape m.
  ( MonadThrow m,
    shape ~ 'Shape '[ 'Dim ('Name "*") rows, 'Dim ('Name "*") cols],
    SingI gradient,
    SingI layout,
    SingI device,
    SingI dataType,
    SingI rows,
    SingI cols
  ) =>
  m (Tensor gradient layout device dataType shape)
eye =
  sEye
    (sing @gradient)
    (sing @layout)
    (sing @device)
    (sing @dataType)
    (sing @rows)
    (sing @cols)

-- | Create a gradually typed square tensor with ones on the diagonal and zeros elsewhere.
--
-- The tensor options are assembled from the gradient, layout, device and
-- data type singletons. The single size is demoted from its singleton and
-- handed to @ATen.eye_lo@ together with those options; the result type uses
-- that same @size@ for both dimensions.
--
-- The returned pointer is wrapped with 'UnsafeTensor', so the shape in the
-- result type is asserted by the signature rather than checked at runtime.
-- The whole action is run through 'unsafeThrowableIO', which yields the
-- result in any 'MonadThrow' monad.
-- NOTE(review): presumably backend failures surface via 'MonadThrow' --
-- confirm against "Torch.Internal.GC".
sEyeSquare ::
  forall gradient layout device dataType size shape m.
  (MonadThrow m, shape ~ 'Shape '[ 'Dim ('Name "*") size, 'Dim ('Name "*") size]) =>
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  SDataType dataType ->
  SSize size ->
  m (Tensor gradient layout device dataType shape)
sEyeSquare gradient layout device dataType size = unsafeThrowableIO $ do
  let opts = tensorOptions gradient layout device dataType
      -- The demoted size is an 'Integer'; it is narrowed to 'Int' before
      -- being passed to the binding. 'fromInteger' wraps silently for values
      -- outside the 'Int' range.
      size' :: Int = fromInteger . forgetIsChecked . fromSing $ size
  UnsafeTensor <$> cast2 ATen.eye_lo size' opts

-- | Create a typed square tensor with ones on the diagonal and zeros elsewhere.
--
-- This is the implicit-argument counterpart of 'sEyeSquare': every singleton
-- that 'sEyeSquare' takes as a value is obtained here from the corresponding
-- type-level argument via 'sing', and the call is otherwise forwarded
-- unchanged.
--
-- The type variables are ordered with @size@ first, so the side length can
-- be chosen with a single visible type application.
eyeSquare ::
  forall size gradient layout device dataType shape m.
  ( MonadThrow m,
    shape ~ 'Shape '[ 'Dim ('Name "*") size, 'Dim ('Name "*") size],
    SingI gradient,
    SingI layout,
    SingI device,
    SingI dataType,
    SingI size
  ) =>
  m (Tensor gradient layout device dataType shape)
eyeSquare =
  sEyeSquare
    (sing @gradient)
    (sing @layout)
    (sing @device)
    (sing @dataType)
    (sing @size)