{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Torch.GraduallyTyped.Tensor.Creation
( sOnes,
ones,
sZeros,
zeros,
sFull,
full,
sRandn,
randn,
sArangeNaturals,
arangeNaturals,
sEye,
eye,
sEyeSquare,
eyeSquare,
)
where
import Control.Monad.Catch (MonadThrow)
import Data.Monoid (All (..))
import Data.Singletons (SingI (..), SingKind (fromSing))
import Torch.GraduallyTyped.DType (SDataType)
import Torch.GraduallyTyped.Device (SDevice)
import Torch.GraduallyTyped.Internal.TensorOptions (tensorDims, tensorOptions)
import Torch.GraduallyTyped.Layout (SLayout (..))
import Torch.GraduallyTyped.Prelude (forgetIsChecked)
import Torch.GraduallyTyped.Random (Generator (..), SGetGeneratorDevice, sCreateWithGenerator)
import Torch.GraduallyTyped.RequiresGradient (SGradient)
import Torch.GraduallyTyped.Scalar (Scalar)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SSize, Shape (..), dimName, dimSize)
import Torch.GraduallyTyped.Tensor.Type (Tensor (..), TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.Internal.Cast (cast2, cast3, cast4)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.TensorFactories as ATen
-- | Create a tensor filled with ones.
--
-- The gradient setting, layout, device, data type and shape are all read
-- from the singletons stored in the given 'TensorSpec'.
--
-- When every dimension name produced by 'tensorDims' is the wildcard @"*"@,
-- the unnamed libtorch factory 'ATen.ones_lo' is called with just the sizes;
-- otherwise the named-dimension overload 'ATen.ones_lNo' also receives the
-- names. The 'IO' work is lifted into @m@ with 'unsafeThrowableIO'.
sOnes ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  TensorSpec gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape)
sOnes TensorSpec {..} = unsafeThrowableIO $ do
  let opts = tensorOptions tsGradient tsLayout tsDevice tsDataType
      dims = tensorDims tsShape
      names = map dimName dims
      sizes = map dimSize dims
  -- Only pass dimension names to libtorch when at least one is not "*".
  tPtr <-
    if all (== "*") names
      then cast2 ATen.ones_lo sizes opts
      else cast3 ATen.ones_lNo sizes names opts
  pure (UnsafeTensor tPtr)
-- | Type-driven variant of 'sOnes': each 'TensorSpec' component is obtained
-- from the 'SingI' instance of the corresponding type parameter.
ones ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  (SingI gradient, SingI layout, SingI device, SingI dataType, SingI shape) =>
  m (Tensor gradient layout device dataType shape)
ones = sOnes spec
  where
    -- Reify every type-level parameter into its singleton.
    spec =
      TensorSpec
        { tsGradient = sing @gradient,
          tsLayout = sing @layout,
          tsDevice = sing @device,
          tsDataType = sing @dataType,
          tsShape = sing @shape
        }
-- | Create a tensor filled with zeros.
--
-- The gradient setting, layout, device, data type and shape are all read
-- from the singletons stored in the given 'TensorSpec'.
--
-- When every dimension name produced by 'tensorDims' is the wildcard @"*"@,
-- the unnamed libtorch factory 'ATen.zeros_lo' is called with just the
-- sizes; otherwise the named-dimension overload 'ATen.zeros_lNo' also
-- receives the names. The 'IO' work is lifted into @m@ with
-- 'unsafeThrowableIO'.
sZeros ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  TensorSpec gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape)
sZeros TensorSpec {..} = unsafeThrowableIO $ do
  let opts = tensorOptions tsGradient tsLayout tsDevice tsDataType
      dims = tensorDims tsShape
      names = map dimName dims
      sizes = map dimSize dims
  -- Only pass dimension names to libtorch when at least one is not "*".
  tPtr <-
    if all (== "*") names
      then cast2 ATen.zeros_lo sizes opts
      else cast3 ATen.zeros_lNo sizes names opts
  pure (UnsafeTensor tPtr)
-- | Type-driven variant of 'sZeros': each 'TensorSpec' component is obtained
-- from the 'SingI' instance of the corresponding type parameter.
zeros ::
  forall gradient layout device dataType shape m.
  MonadThrow m =>
  (SingI gradient, SingI layout, SingI device, SingI dataType, SingI shape) =>
  m (Tensor gradient layout device dataType shape)
zeros = sZeros spec
  where
    -- Reify every type-level parameter into its singleton.
    spec =
      TensorSpec
        { tsGradient = sing @gradient,
          tsLayout = sing @layout,
          tsDevice = sing @device,
          tsDataType = sing @dataType,
          tsShape = sing @shape
        }
-- | Create a tensor in which every element is the scalar @input@.
--
-- The gradient setting, layout, device, data type and shape are all read
-- from the singletons stored in the given 'TensorSpec'.
--
-- When every dimension name produced by 'tensorDims' is the wildcard @"*"@,
-- the unnamed libtorch factory 'ATen.full_lso' is called; otherwise the
-- named-dimension overload 'ATen.full_lsNo' also receives the names. The
-- 'IO' work is lifted into @m@ with 'unsafeThrowableIO'.
sFull ::
  forall gradient layout device dataType shape input m.
  (MonadThrow m, Scalar input) =>
  TensorSpec gradient layout device dataType shape ->
  input ->
  m (Tensor gradient layout device dataType shape)
sFull TensorSpec {..} input = unsafeThrowableIO $ do
  let opts = tensorOptions tsGradient tsLayout tsDevice tsDataType
      dims = tensorDims tsShape
      names = map dimName dims
      sizes = map dimSize dims
  -- Only pass dimension names to libtorch when at least one is not "*".
  tPtr <-
    if all (== "*") names
      then cast3 ATen.full_lso sizes input opts
      else cast4 ATen.full_lsNo sizes input names opts
  pure (UnsafeTensor tPtr)
-- | Type-driven variant of 'sFull': each 'TensorSpec' component is obtained
-- from the 'SingI' instance of the corresponding type parameter, and the
-- resulting function takes the fill value.
full ::
  forall gradient layout device dataType shape input m.
  (MonadThrow m, SingI gradient, SingI layout, SingI device, SingI dataType, SingI shape, Scalar input) =>
  input ->
  m (Tensor gradient layout device dataType shape)
full = sFull spec
  where
    -- Reify every type-level parameter into its singleton.
    spec =
      TensorSpec
        { tsGradient = sing @gradient,
          tsLayout = sing @layout,
          tsDevice = sing @device,
          tsDataType = sing @dataType,
          tsShape = sing @shape
        }
-- | Create a tensor of random values with libtorch's @randn@ factory,
-- drawing randomness from the supplied 'Generator'.
--
-- Tensor options, dimensions and the generator pointer are all provided by
-- 'sCreateWithGenerator', which also produces the returned generator. The
-- result's device is the unification @device <+> generatorDevice@ of the
-- spec's device and the generator's device.
--
-- When every dimension name is the wildcard @"*"@, 'ATen.randn_lGo' is
-- called with just the sizes; otherwise 'ATen.randn_lGNo' also receives
-- the names.
sRandn ::
  forall gradient layout device dataType shape generatorDevice m.
  (SGetGeneratorDevice generatorDevice, MonadThrow m) =>
  TensorSpec gradient layout device dataType shape ->
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
sRandn tSpec generator =
  sCreateWithGenerator tSpec generator $ \opts dims genPtr ->
    let names = map dimName dims
        sizes = map dimSize dims
     in -- Only pass dimension names to libtorch when at least one is not "*".
        if all (== "*") names
          then cast3 ATen.randn_lGo sizes genPtr opts
          else cast4 ATen.randn_lGNo sizes genPtr names opts
-- | Type-driven variant of 'sRandn': each 'TensorSpec' component is obtained
-- from the 'SingI' instance of the corresponding type parameter, and the
-- resulting function takes only the generator.
randn ::
  forall gradient layout device dataType shape generatorDevice m.
  (SGetGeneratorDevice generatorDevice, MonadThrow m) =>
  (SingI gradient, SingI layout, SingI device, SingI dataType, SingI shape) =>
  Generator generatorDevice ->
  m (Tensor gradient layout (device <+> generatorDevice) dataType shape, Generator (device <+> generatorDevice))
randn = sRandn spec
  where
    -- Reify every type-level parameter into its singleton.
    spec =
      TensorSpec
        { tsGradient = sing @gradient,
          tsLayout = sing @layout,
          tsDevice = sing @device,
          tsDataType = sing @dataType,
          tsShape = sing @shape
        }
-- | Create a one-dimensional tensor with libtorch's @arange@ factory, using
-- the value of the size singleton as the (exclusive) end point.
--
-- The result shape is a single unnamed (@"*"@) dimension of length @size@,
-- as fixed by the @shape@ equality constraint. The 'IO' work is lifted into
-- @m@ with 'unsafeThrowableIO'.
sArangeNaturals ::
  forall m gradient layout device dataType size shape.
  (MonadThrow m, shape ~ 'Shape '[ 'Dim ('Name "*") size]) =>
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  SDataType dataType ->
  SSize size ->
  m (Tensor gradient layout device dataType shape)
sArangeNaturals gradient layout device dataType size =
  unsafeThrowableIO $ UnsafeTensor <$> cast2 ATen.arange_so end options
  where
    options = tensorOptions gradient layout device dataType
    -- Demote the size singleton to a plain number, dropping the checked tag.
    end = forgetIsChecked (fromSing size)
-- | Type-driven variant of 'sArangeNaturals': all singleton arguments are
-- obtained from 'SingI' instances of the type parameters.
arangeNaturals ::
  forall size gradient layout device dataType shape m.
  ( MonadThrow m,
    shape ~ 'Shape '[ 'Dim ('Name "*") size],
    SingI gradient,
    SingI layout,
    SingI device,
    SingI dataType,
    SingI size
  ) =>
  m (Tensor gradient layout device dataType shape)
arangeNaturals =
  let singGradient = sing @gradient
      singLayout = sing @layout
      singDevice = sing @device
      singDataType = sing @dataType
      singSize = sing @size
   in sArangeNaturals singGradient singLayout singDevice singDataType singSize
-- | Create a @rows@-by-@cols@ matrix with libtorch's @eye@ factory (ones on
-- the main diagonal).
--
-- Both size singletons are demoted and converted to 'Int' before being
-- passed to 'ATen.eye_llo'. The result has two unnamed (@"*"@) dimensions,
-- as fixed by the @shape@ equality constraint. The 'IO' work is lifted into
-- @m@ with 'unsafeThrowableIO'.
sEye ::
  forall gradient layout device dataType rows cols shape m.
  (MonadThrow m, shape ~ 'Shape '[ 'Dim ('Name "*") rows, 'Dim ('Name "*") cols]) =>
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  SDataType dataType ->
  SSize rows ->
  SSize cols ->
  m (Tensor gradient layout device dataType shape)
sEye gradient layout device dataType rows cols =
  unsafeThrowableIO $ UnsafeTensor <$> cast3 ATen.eye_llo nRows nCols options
  where
    options = tensorOptions gradient layout device dataType
    nRows, nCols :: Int
    nRows = fromInteger (forgetIsChecked (fromSing rows))
    nCols = fromInteger (forgetIsChecked (fromSing cols))
-- | Type-driven variant of 'sEye': all singleton arguments are obtained from
-- 'SingI' instances of the type parameters.
eye ::
  forall rows cols gradient layout device dataType shape m.
  ( MonadThrow m,
    shape ~ 'Shape '[ 'Dim ('Name "*") rows, 'Dim ('Name "*") cols],
    SingI gradient,
    SingI layout,
    SingI device,
    SingI dataType,
    SingI rows,
    SingI cols
  ) =>
  m (Tensor gradient layout device dataType shape)
eye =
  let singGradient = sing @gradient
      singLayout = sing @layout
      singDevice = sing @device
      singDataType = sing @dataType
      singRows = sing @rows
      singCols = sing @cols
   in sEye singGradient singLayout singDevice singDataType singRows singCols
-- | Create a square @size@-by-@size@ matrix with libtorch's single-argument
-- @eye@ factory (ones on the main diagonal).
--
-- The size singleton is demoted and converted to 'Int' before being passed
-- to 'ATen.eye_lo'. The result has two unnamed (@"*"@) dimensions of the
-- same size, as fixed by the @shape@ equality constraint. The 'IO' work is
-- lifted into @m@ with 'unsafeThrowableIO'.
sEyeSquare ::
  forall gradient layout device dataType size shape m.
  (MonadThrow m, shape ~ 'Shape '[ 'Dim ('Name "*") size, 'Dim ('Name "*") size]) =>
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  SDataType dataType ->
  SSize size ->
  m (Tensor gradient layout device dataType shape)
sEyeSquare gradient layout device dataType size =
  unsafeThrowableIO $ UnsafeTensor <$> cast2 ATen.eye_lo n options
  where
    options = tensorOptions gradient layout device dataType
    n :: Int
    n = fromInteger (forgetIsChecked (fromSing size))
-- | Like 'sEyeSquare', but gradient, layout, device, data type and size are
-- taken from 'SingI' instances rather than passed as explicit singletons.
--
-- @size@ is quantified first, so with @TypeApplications@ it can be supplied
-- as the first type argument.
eyeSquare ::
  forall size gradient layout device dataType shape m.
  ( MonadThrow m,
    shape ~ 'Shape '[ 'Dim ('Name "*") size, 'Dim ('Name "*") size],
    SingI gradient,
    SingI layout,
    SingI device,
    SingI dataType,
    SingI size
  ) =>
  m (Tensor gradient layout device dataType shape)
eyeSquare =
  sEyeSquare
    (sing @gradient)
    (sing @layout)
    (sing @device)
    (sing @dataType)
    (sing @size)