{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE ViewPatterns #-}

module Torch.GraduallyTyped.Tensor.Type where

import Control.Applicative (empty)
import Control.Category ((>>>))
import Control.Exception (Exception (..))
import Control.Monad (forM, forM_, when, (<=<), (>=>))
import Control.Monad.Catch (MonadThrow (..))
import Data.Bifunctor (bimap)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Functor ((<&>))
import Data.Int (Int16)
import Data.List.NonEmpty (NonEmpty ((:|)), nonEmpty, unzip)
import Data.Maybe (maybeToList)
import Data.Proxy (Proxy (..))
import Data.Singletons (SingI (sing), fromSing)
import Data.Typeable (Typeable)
import qualified Data.Vector as V hiding (uncons)
import qualified Data.Vector.Generic.Sized.Internal as SVI
import qualified Data.Vector.Sized as SV
import Foreign (Ptr, Word8, castPtr, fromBool, peekElemOff, pokeElemOff, withForeignPtr)
import Foreign.ForeignPtr (ForeignPtr)
import GHC.Generics (Generic)
import GHC.TypeLits (KnownNat, KnownSymbol, Nat, Symbol, natVal, symbolVal)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..), SDeviceType (..))
import Torch.GraduallyTyped.Internal.TensorOptions (tensorOptions)
import qualified Torch.GraduallyTyped.Internal.Vector as V
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, ifM, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (InsertDimF, ReplaceDimF)
import Torch.GraduallyTyped.Shape.Type (By (ByIndex), Dim (..), Name (..), SDim (..), SName (..), SShape (..), SSize (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.HList (HList (..), pattern (:.))
import Torch.Internal.Cast (cast0, cast1, cast2, cast4)
import Torch.Internal.Class (Castable (..))
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.Extra as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen (Tensor, TensorList, TensorOptions)
import qualified Torch.Internal.Unmanaged.Type.Tensor as Unmanaged (tensor_data_ptr)
import qualified Torch.Tensor (Tensor (Unsafe))
import Prelude hiding (unzip, unzip3)

-- $setup
-- >>> import Torch.GraduallyTyped.Prelude.List (SList (..))
-- >>> import Torch.GraduallyTyped

-- | A gradually typed tensor.
--
-- Each of the five type-level indices can either be fixed (checked) or left
-- unchecked, so the same type covers fully typed, partially typed, and
-- untyped tensors.
--
-- @
--                          ┌─► Compute device, e.g. `'Device 'CPU`
--                          │
--                          │               ┌─► List of dimensions, e.g. `'Shape '[ 'Dim 'UncheckedName ('Size 8), 'Dim 'UncheckedName ('Size 1) ]`
--                          │               │
-- Tensor gradient layout device dataType shape
--           │       │              │
--           │       │              └─► Data type, e.g. `'DataType 'Float`
--           │       │
--           │       └─► Memory layout, e.g. `'Layout 'Dense`
--           │
--           └─► Whether or not the tensor requires a gradient, e.g. `'Gradient 'WithGradient` for one that does
-- @
newtype
  Tensor
    (gradient :: Gradient RequiresGradient)
    (layout :: Layout LayoutType)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (shape :: Shape [Dim (Name Symbol) (Size Nat)])
  where
  -- | Unsafe constructor for tensors.
  -- Do not call this constructor directly,
  -- use smart constructors like 'ones' or 'randn' instead.
  --
  -- The runtime representation is just a foreign pointer to an ATen tensor;
  -- none of the type-level indices are checked against it here.
  UnsafeTensor ::
    forall gradient layout device dataType shape.
    ForeignPtr ATen.Tensor ->
    Tensor gradient layout device dataType shape

-- All indices are nominal: without this annotation the phantom indices would
-- default to phantom roles, and 'coerce' could silently change a tensor's
-- gradient, layout, device, data type, or shape at the type level.
type role Tensor nominal nominal nominal nominal nominal

-- | Renders a gradually typed tensor through the 'Show' instance of the
-- untyped 'Torch.Tensor.Tensor'. The same foreign pointer is re-wrapped
-- without any conversion.
instance Show (Tensor gradient layout device dataType shape) where
  show (UnsafeTensor t) = show (Torch.Tensor.Unsafe t)

-- | Runtime description of a 'Tensor': one singleton value for each of the
-- tensor's five type-level indices.
--
-- A spec either fixes each property or leaves it unchecked (for example
-- @SUncheckedGradient WithoutGradient@). It is passed to smart constructors
-- such as @sOnes@ to describe the tensor that should be created.
data
  TensorSpec
    (gradient :: Gradient RequiresGradient)
    (layout :: Layout LayoutType)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (shape :: Shape [Dim (Name Symbol) (Size Nat)])
  where
  TensorSpec ::
    forall gradient layout device dataType shape.
    { -- | Whether or not gradient computations are required.
      tsGradient :: SGradient gradient,
      -- | Memory layout.
      tsLayout :: SLayout layout,
      -- | Compute device.
      tsDevice :: SDevice device,
      -- | Data type of the elements.
      tsDataType :: SDataType dataType,
      -- | List of dimensions.
      tsShape :: SShape shape
    } ->
    TensorSpec gradient layout device dataType shape
  deriving stock (Show, Generic)

-- | Alias for a fully untyped tensor: every index ('UncheckedGradient,
-- 'UncheckedLayout, 'UncheckedDevice, 'UncheckedDataType, 'UncheckedShape)
-- is left unchecked, including whether or not it requires gradients.
type UncheckedTensor = Tensor 'UncheckedGradient 'UncheckedLayout 'UncheckedDevice 'UncheckedDataType 'UncheckedShape

-- | Alias for a tensor that requires gradients ('Gradient 'WithGradient)
-- but whose layout, device, data type, and shape are unchecked.
type UncheckedParameter = Tensor ('Gradient 'WithGradient) 'UncheckedLayout 'UncheckedDevice 'UncheckedDataType 'UncheckedShape

-- | Alias for a dense tensor on CPU memory without gradients.
-- The data type and shape are still to be supplied.
type CPUTensor = Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) ('Device 'CPU)

-- | Alias for a dense tensor on CPU memory with gradients.
-- The data type and shape are still to be supplied.
type CPUParameter = Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU)

-- | Alias for a sparse tensor on CPU memory without gradients.
-- The data type and shape are still to be supplied.
type SparseCPUTensor = Tensor ('Gradient 'WithoutGradient) ('Layout 'Sparse) ('Device 'CPU)

-- | Alias for a sparse tensor on CPU memory with gradients.
-- The data type and shape are still to be supplied.
type SparseCPUParameter = Tensor ('Gradient 'WithGradient) ('Layout 'Sparse) ('Device 'CPU)

-- | Alias for a dense tensor on the CUDA device @deviceId@ without gradients.
-- The data type and shape are still to be supplied.
type CUDATensor deviceId = Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) ('Device ('CUDA deviceId))

-- | Alias for a dense tensor on the CUDA device @deviceId@ with gradients.
-- The data type and shape are still to be supplied.
type CUDAParameter deviceId = Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device ('CUDA deviceId))

-- | Alias for a sparse tensor on the CUDA device @deviceId@ without gradients.
-- The data type and shape are still to be supplied.
type SparseCUDATensor deviceId = Tensor ('Gradient 'WithoutGradient) ('Layout 'Sparse) ('Device ('CUDA deviceId))

-- | Alias for a sparse tensor on the CUDA device @deviceId@ with gradients.
-- The data type and shape are still to be supplied.
type SparseCUDAParameter deviceId = Tensor ('Gradient 'WithGradient) ('Layout 'Sparse) ('Device ('CUDA deviceId))

-- | Arithmetic on tensors backed by ATen.
--
-- Each method calls the matching ATen function ('ATen.add_tt', 'ATen.sub_tt',
-- 'ATen.mul_tt', 'ATen.neg_t', 'ATen.abs_t', 'ATen.sign_t'). It leaves 'IO'
-- through 'unsafePerformIO', so the results can be used as ordinary values.
-- Both operands of a binary operation must carry identical type-level indices.
instance Num (Tensor gradient layout device dataType shape) where
  (+) = (unsafePerformIO .) . cast2 ATen.add_tt
  (-) = (unsafePerformIO .) . cast2 ATen.sub_tt
  (*) = (unsafePerformIO .) . cast2 ATen.mul_tt
  negate = unsafePerformIO . cast1 ATen.neg_t
  abs = unsafePerformIO . cast1 ATen.abs_t
  signum = unsafePerformIO . cast1 ATen.sign_t

  -- An integer literal carries no layout, device, data type, or shape. This
  -- instance has no constraints that could supply them at runtime, so no
  -- tensor can be built. Fail with an explanation instead of a bare
  -- 'undefined'.
  fromInteger _ =
    error $
      "Torch.GraduallyTyped.Tensor.Type: fromInteger is not supported for Tensor; "
        ++ "the layout, device, data type, and shape of the result are unknown. "
        ++ "Use a smart constructor instead."

-- | Marshals a gradually typed tensor to and from its ATen pointer.
-- The same 'ForeignPtr' is passed through in both directions; nothing is
-- copied or checked.
instance
  Castable
    (Tensor gradient layout device dataType shape)
    (ForeignPtr ATen.Tensor)
  where
  cast (UnsafeTensor atenTensor) f = f atenTensor
  uncast atenTensor f = f $ UnsafeTensor atenTensor

-- | Marshals a homogeneous list of tensors to and from an ATen tensor list.
-- Each element goes through the single-tensor 'Castable' instance, and the
-- resulting list of pointers is converted with the existing list instance.
instance
  Castable
    [Tensor gradient layout device dataType shape]
    (ForeignPtr ATen.TensorList)
  where
  cast xs f = do
    ptrList <- mapM (\x -> (cast x return :: IO (ForeignPtr ATen.Tensor))) xs
    cast ptrList f
  uncast xs f = uncast xs $ \ptrList -> do
    tensorList <- mapM (\(x :: ForeignPtr ATen.Tensor) -> uncast x return) ptrList
    f tensorList

-- | Base case of the heterogeneous-list marshalling: the empty 'HList'
-- corresponds to the empty list of pointers. Uncasting a non-empty list
-- fails in 'IO', because the runtime list is longer than the type says.
instance Castable (HList '[]) [ForeignPtr ATen.Tensor] where
  cast HNil f = f []
  uncast [] f = f HNil
  uncast (_ : _) _ = fail "The list of tensors has more elements than expected. This means that the runtime length of the list exceeded its compile-time length."

-- | Inductive case of the heterogeneous-list marshalling.
--
-- The head tensor is converted with the single-tensor instance, and the tail
-- is converted recursively. Uncasting fails in 'IO' when the runtime list runs
-- out before the type-level list does.
instance
  ( Castable (HList tensors) [ForeignPtr ATen.Tensor]
  ) =>
  Castable (HList (Tensor gradient layout device dataType shape ': tensors)) [ForeignPtr ATen.Tensor]
  where
  cast (HCons (tensor, tensors)) f = do
    ptr <- cast tensor pure
    ptrList <- cast tensors pure
    f (ptr : ptrList)
  uncast [] _ = fail "The list of tensors ended prematurely. This means that the runtime length of the list was smaller than its compile-time length."
  uncast (ptr : ptrList) f = do
    tensor <- uncast ptr pure
    tensors <- uncast ptrList pure
    f (tensor :. tensors)

-- | Marshals a heterogeneous list of tensors to and from an ATen tensor list.
-- It goes through the intermediate @[ForeignPtr ATen.Tensor]@ representation
-- provided by the 'HList' instances.
instance
  Castable (HList l) [ForeignPtr ATen.Tensor] =>
  Castable (HList l) (ForeignPtr ATen.TensorList)
  where
  cast xs f = do
    ts <- cast xs return :: IO [ForeignPtr ATen.Tensor]
    cast ts f
  uncast xs f = uncast xs $ \(ptrList :: [ForeignPtr ATen.Tensor]) -> do
    ts <- uncast ptrList return :: IO (HList l)
    f ts

-- | Takes a tensor that may or may not require gradient computations
-- and returns a copy that does not require them.
withoutGradient ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | copy of the input tensor with gradient computations turned off.
  IO (Tensor ('Gradient 'WithoutGradient) layout device dataType shape)
-- NOTE(review): the result comes from ATen's set_requires_grad binding;
-- confirm whether the input tensor's own flag is also changed.
withoutGradient tensor = cast2 ATen.tensor_set_requires_grad_b tensor False

-- | Takes a tensor that may or may not require gradient computations
-- and returns a copy that requires them.
withGradient ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | copy of the input tensor with gradient computations turned on.
  IO (Tensor ('Gradient 'WithGradient) layout device dataType shape)
-- NOTE(review): the result comes from ATen's set_requires_grad binding;
-- confirm whether the input tensor's own flag is also changed.
withGradient tensor = cast2 ATen.tensor_set_requires_grad_b tensor True

-- | Turn gradient computations off or on for a tensor.
--
-- The singleton decides at runtime which flag is passed to ATen. This works
-- both when the requested gradient is fixed at the type level and when it is
-- unchecked.
sSetGradient ::
  forall gradient gradient' layout device dataType shape.
  -- | requested gradient setting
  SGradient gradient ->
  -- | input tensor
  Tensor gradient' layout device dataType shape ->
  -- | tensor with the requested gradient setting
  IO (Tensor gradient layout device dataType shape)
sSetGradient gradient tensor =
  cast2 ATen.tensor_set_requires_grad_b tensor requiresGrad
  where
    -- Demote the singleton to a plain 'RequiresGradient' and map it to the
    -- boolean flag that ATen expects.
    requiresGrad = case forgetIsChecked (fromSing gradient) of
      WithoutGradient -> False
      WithGradient -> True

class SGetGradient (gradient :: Gradient RequiresGradient) where
  -- | Returns the gradually typed information for whether or not gradient computations for the tensor are turned on.
  --
  -- >>> sOnes' gradient = sOnes $ TensorSpec gradient (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
  -- >>> t <- sOnes' $ SGradient SWithGradient
  -- >>> sGetGradient t
  -- SGradient SWithGradient
  -- >>> t <- sOnes' $ SUncheckedGradient WithoutGradient
  -- >>> sGetGradient t
  -- SUncheckedGradient WithoutGradient
  sGetGradient ::
    forall layout device dataType shape.
    -- | input tensor
    Tensor gradient layout device dataType shape ->
    -- | information about whether or not gradient computations are required
    SGradient gradient

  -- | Returns the untyped information for whether or not gradient computations for the tensor are turned on.
  --
  -- >>> sOnes' gradient = sOnes $ TensorSpec gradient (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
  -- >>> t <- sOnes' $ SGradient SWithGradient
  -- >>> getRequiresGradient t
  -- WithGradient
  -- >>> t <- sOnes' $ SUncheckedGradient WithoutGradient
  -- >>> getRequiresGradient t
  -- WithoutGradient
  getRequiresGradient ::
    forall layout device dataType shape.
    -- | input tensor
    Tensor gradient layout device dataType shape ->
    -- | information about whether or not gradient computations are required
    RequiresGradient
  -- Default: demote the singleton from 'sGetGradient' and discard checkedness.
  getRequiresGradient tensor = forgetIsChecked . fromSing $ sGetGradient tensor

-- | Without a static annotation, the requires-grad flag is read from the
-- tensor at runtime and reported as an unchecked singleton.
instance SGetGradient 'UncheckedGradient where
  sGetGradient tensor
    | requiresGrad = SUncheckedGradient WithGradient
    | otherwise = SUncheckedGradient WithoutGradient
    where
      -- 'unsafePerformIO' lets this pure accessor query the flag through the
      -- FFI; presumably a read-only query of the tensor (confirm).
      requiresGrad :: Bool
      requiresGrad = unsafePerformIO (cast1 ATen.tensor_requires_grad tensor)

-- | Returns the statically known singleton, after checking at runtime that
-- the tensor really does require gradients. A mismatch indicates an internal
-- inconsistency and is reported with 'error'.
instance SGetGradient ('Gradient 'WithGradient) where
  sGetGradient tensor
    | unsafePerformIO (cast1 ATen.tensor_requires_grad tensor) = SGradient SWithGradient
    | otherwise =
      error $
        "The tensor should require gradient computations but doesn't. "
          <> gitHubErrorMsg

-- | Returns the statically known singleton, after checking at runtime that
-- the tensor does not require gradients. A mismatch indicates an internal
-- inconsistency and is reported with 'error'.
instance SGetGradient ('Gradient 'WithoutGradient) where
  sGetGradient tensor
    | unsafePerformIO (cast1 ATen.tensor_requires_grad tensor) =
      error $
        "The tensor should not require gradient computations but does. "
          <> gitHubErrorMsg
    | otherwise = SGradient SWithoutGradient

-- | Error thrown by 'sCheckedGradient' when a tensor's runtime gradient
-- setting differs from the expected one.
data GradientError = GradientError
  { -- | gradient setting that was expected
    geExpected :: RequiresGradient,
    -- | gradient setting the tensor actually has
    geActual :: RequiresGradient
  }
  deriving stock (Show, Typeable)

-- | Human-readable rendering that reports the actual setting first, then the
-- expected one.
instance Exception GradientError where
  displayException GradientError {..} =
    "The tensor's information about whether or not gradient computations are required reads `"
      <> show geActual
      <> "` instead of `"
      <> show geExpected
      <> "`."

-- | Checks whether or not gradient computations are required for a tensor
-- and returns the same tensor with a static annotation, wrapped in a 'MonadThrow' 'm'.
--
-- For instance, if 'm' is 'Maybe', then the result will be wrapped in 'Just'
-- if and only if gradients are computed for the tensor according to the argument @gradient'@.
-- If gradients are expected but none are computed, then the result will be 'Nothing'.
-- If gradients are not expected but are computed, then the result will be 'Nothing' as well.
-- On mismatch, a 'GradientError' is thrown via 'throwM'.
--
-- In the REPL, 'm' will default to 'IO':
-- >>> t <- sOnes $ TensorSpec (SUncheckedGradient WithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
-- >>> t' <- sCheckedGradient (SGradient SWithGradient) t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
-- >>> t' <- sCheckedGradient (SGradient SWithoutGradient) t
-- *** Exception: GradientError {geExpected = WithoutGradient, geActual = WithGradient}
sCheckedGradient ::
  forall gradient' m gradient layout device dataType shape.
  (SGetGradient gradient, MonadThrow m, Catch (gradient <+> gradient')) =>
  -- | expected gradient setting
  SGradient gradient' ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient' layout device dataType shape)
sCheckedGradient gradient' tensor =
  let actualGradient :: RequiresGradient
      actualGradient = forgetIsChecked . fromSing $ sGetGradient tensor
      expectedGradient :: RequiresGradient
      expectedGradient = forgetIsChecked . fromSing $ gradient'
   in if actualGradient == expectedGradient
        then -- Only the phantom annotation changes; the tensor is not copied.
          pure . coerce $ tensor
        else throwM $ GradientError expectedGradient actualGradient

-- | Like 'sCheckedGradient', but the expected gradient setting is taken from
-- the type-level argument @gradient'@ (via 'SingI') instead of a singleton.
checkedGradient ::
  forall gradient' m gradient layout device dataType shape.
  (SingI gradient', SGetGradient gradient, MonadThrow m, Catch (gradient <+> gradient')) =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient' layout device dataType shape)
checkedGradient = sCheckedGradient (sing @gradient')
sing @gradient')

-- | Returns the input tensor but with 'UncheckedGradient' as gradient type annotation.
-- Any static information about whether or not the gradient computation is required for the tensor is lost.
-- However, the tensor's underlying data structure is not changed.
--
-- >>> t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
-- >>> :type uncheckedGradient t
-- uncheckedGradient t
--   :: Tensor
--        'UncheckedGradient
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
uncheckedGradient ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | tensor without checked gradient information
  Tensor 'UncheckedGradient layout device dataType shape
uncheckedGradient = coerce

-- | Returns a dense copy of the tensor.
--
-- The conversion is done by @ATen.tensor_to_dense@ and run in 'm' via
-- 'unsafeThrowableIO'.
toDense ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | dense copy
  m (Tensor gradient ('Layout 'Dense) device dataType shape)
toDense = unsafeThrowableIO . cast1 ATen.tensor_to_dense

-- | Returns a sparse copy of the tensor.
--
-- The conversion is done by @ATen.tensor_to_sparse@ and run in 'm' via
-- 'unsafeThrowableIO'.
toSparse ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | sparse copy
  m (Tensor gradient ('Layout 'Sparse) device dataType shape)
toSparse = unsafeThrowableIO . cast1 ATen.tensor_to_sparse

-- | Set the memory layout of a tensor to a given layout.
--
-- The target layout is demoted from the singleton @layout@. The conversion
-- uses @ATen.tensor_to_dense@ or @ATen.tensor_to_sparse@ and runs in 'm' via
-- 'unsafeThrowableIO'.
sSetLayout ::
  forall m gradient layout layout' device dataType shape.
  MonadThrow m =>
  -- | target memory layout
  SLayout layout ->
  -- | input tensor
  Tensor gradient layout' device dataType shape ->
  -- | tensor with the target memory layout
  m (Tensor gradient layout device dataType shape)
sSetLayout layout input = unsafeThrowableIO $ cast1 convert input
  where
    -- Both converters share the same FFI type, so only the choice differs.
    convert = case forgetIsChecked (fromSing layout) of
      Dense -> ATen.tensor_to_dense
      Sparse -> ATen.tensor_to_sparse

class SGetLayout (layout :: Layout LayoutType) where
  -- | Returns the gradually typed memory layout of the input tensor.
  --
  -- >>> sOnes' layout = sOnes $ TensorSpec (SGradient SWithGradient) layout (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
  -- >>> t <- sOnes' $ SLayout SDense
  -- >>> sGetLayout t
  -- SLayout SDense
  -- >>> t <- sOnes' $ SUncheckedLayout Dense
  -- >>> sGetLayout t
  -- SUncheckedLayout Dense
  sGetLayout ::
    forall gradient device dataType shape.
    -- | input
    Tensor gradient layout device dataType shape ->
    -- | memory layout
    SLayout layout

  -- | Returns the untyped memory layout of the input tensor.
  --
  -- >>> sOnes' layout = sOnes $ TensorSpec (SGradient SWithGradient) layout (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
  -- >>> t <- sOnes' $ SLayout SDense
  -- >>> getLayoutType t
  -- Dense
  -- >>> t <- sOnes' $ SUncheckedLayout Dense
  -- >>> getLayoutType t
  -- Dense
  getLayoutType ::
    forall gradient device dataType shape.
    -- | input
    Tensor gradient layout device dataType shape ->
    -- | memory layout
    LayoutType
  -- Default: demote the singleton from 'sGetLayout' and discard checkedness.
  getLayoutType tensor = forgetIsChecked . fromSing $ sGetLayout tensor

-- | Without a static annotation, sparsity is read from the tensor at runtime
-- and reported as an unchecked singleton.
instance SGetLayout 'UncheckedLayout where
  sGetLayout tensor
    | isSparse = SUncheckedLayout Sparse
    | otherwise = SUncheckedLayout Dense
    where
      -- 'unsafePerformIO' lets this pure accessor query the tensor through the
      -- FFI; presumably a read-only query (confirm).
      isSparse :: Bool
      isSparse = unsafePerformIO (cast1 ATen.tensor_is_sparse tensor)

-- | Returns the statically known singleton, after checking at runtime that
-- the tensor really is sparse. A mismatch is reported with 'error'.
instance SGetLayout ('Layout 'Sparse) where
  sGetLayout tensor
    | unsafePerformIO (cast1 ATen.tensor_is_sparse tensor) = SLayout SSparse
    | otherwise =
      error $
        "The tensor should be sparse but isn't. "
          <> gitHubErrorMsg

-- | Returns the statically known singleton, after checking at runtime that
-- the tensor is not sparse. A mismatch is reported with 'error'.
instance SGetLayout ('Layout 'Dense) where
  sGetLayout tensor
    | unsafePerformIO (cast1 ATen.tensor_is_sparse tensor) =
      error $
        "The tensor should be dense but isn't. "
          <> gitHubErrorMsg
    | otherwise = SLayout SDense

-- | Error thrown by 'sCheckedLayout' when a tensor's runtime memory layout
-- differs from the expected one.
data LayoutError = LayoutError
  { -- | memory layout that was expected
    leExpected :: LayoutType,
    -- | memory layout the tensor actually has
    leActual :: LayoutType
  }
  deriving stock (Show, Typeable)

-- | Human-readable rendering naming the expected and the actual layout.
instance Exception LayoutError where
  displayException LayoutError {..} =
    "The tensor does not have the memory layout `"
      <> show leExpected
      <> "` but `"
      <> show leActual
      <> "`."

-- | Checks whether or not the input tensor has the memory layout @layout'@
-- and returns the same tensor with a static annotation, wrapped in a 'MonadThrow' 'm'.
--
-- For instance, if 'm' is 'Maybe', then the result will be wrapped in 'Just' if and only if the tensor has indeed the memory layout @layout'@.
-- If it does not have it, then the result will be 'Nothing'.
-- On mismatch, a 'LayoutError' is thrown via 'throwM'.
--
-- In the REPL, 'm' will default to 'IO':
-- >>> t <- sOnes $ TensorSpec (SGradient SWithGradient) (SUncheckedLayout Dense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
-- >>> t' <- sCheckedLayout (SLayout SDense) t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
-- >>> t' <- sCheckedLayout (SLayout SSparse) t
-- *** Exception: LayoutError {leExpected = Sparse, leActual = Dense}
sCheckedLayout ::
  forall layout' m gradient layout device dataType shape.
  (SGetLayout layout, MonadThrow m, Catch (layout <+> layout')) =>
  -- | expected memory layout
  SLayout layout' ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient layout' device dataType shape)
sCheckedLayout layout' tensor =
  let actualLayout :: LayoutType
      actualLayout = forgetIsChecked . fromSing $ sGetLayout tensor
      expectedLayout :: LayoutType
      expectedLayout = forgetIsChecked . fromSing $ layout'
   in if actualLayout == expectedLayout
        then -- Only the phantom annotation changes; the tensor is not copied.
          pure . coerce $ tensor
        else throwM $ LayoutError expectedLayout actualLayout

-- | Like 'sCheckedLayout', but the expected memory layout is taken from the
-- type-level argument @layout'@ (via 'SingI') instead of a singleton.
checkedLayout ::
  forall layout' m gradient layout device dataType shape.
  (SingI layout', SGetLayout layout, MonadThrow m, Catch (layout <+> layout')) =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient layout' device dataType shape)
checkedLayout = sCheckedLayout (sing @layout')

-- | Returns the input tensor but with 'UncheckedLayout' as memory layout type annotation.
-- Any static information about the tensor's memory layout is thus erased.
-- However, the tensor's underlying data structure is not changed.
--
-- >>> t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
-- >>> :type uncheckedLayout t
-- uncheckedLayout t
--   :: Tensor
--        ('Gradient 'WithGradient)
--        'UncheckedLayout
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
uncheckedLayout ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | tensor without checked memory layout
  Tensor gradient 'UncheckedLayout device dataType shape
uncheckedLayout = coerce

-- | Returns a copy of the tensor in CPU memory.
--
-- The transfer is done by @ATen.tensor_cpu@ and run in 'm' via
-- 'unsafeThrowableIO'.
cpu ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | copy in CPU memory
  m (Tensor gradient layout ('Device 'CPU) dataType shape)
cpu = unsafeThrowableIO . cast1 ATen.tensor_cpu

-- | Returns a copy of the tensor in CUDA memory.
--
-- The transfer is done by @ATen.tensor_cuda@ and run in 'm' via
-- 'unsafeThrowableIO'. The result is annotated as living on CUDA device 0.
--
-- NOTE(review): the @'CUDA 0@ annotation assumes @ATen.tensor_cuda@ targets
-- device 0 rather than the current CUDA device. Confirm this when the current
-- device may have been changed.
cuda ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | copy in CUDA memory
  m (Tensor gradient layout ('Device ('CUDA 0)) dataType shape)
cuda = unsafeThrowableIO . cast1 ATen.tensor_cuda

-- | Reallocates a tensor on the specified device.
--
-- The target device is demoted from the singleton @device@:
--
-- * @CPU@ uses @ATen.tensor_cpu@.
-- * @CUDA 0@ uses @ATen.tensor_cuda@.
-- * Any other CUDA index takes the input's tensor options, sets the device
--   index on them, and calls @ATen.tensor_to_obb@ with non-blocking and
--   copy both disabled.
--
-- All branches run in 'm' via 'unsafeThrowableIO'.
sSetDevice ::
  forall m gradient layout device device' dataType shape.
  MonadThrow m =>
  -- | target device
  SDevice device ->
  -- | input tensor
  Tensor gradient layout device' dataType shape ->
  -- | tensor on the target device
  m (Tensor gradient layout device dataType shape)
sSetDevice device input =
  case forgetIsChecked (fromSing device) of
    CPU -> unsafeThrowableIO . cast1 ATen.tensor_cpu $ input
    -- NOTE(review): relies on @ATen.tensor_cuda@ targeting device 0 rather
    -- than the current CUDA device; confirm.
    CUDA 0 -> unsafeThrowableIO . cast1 ATen.tensor_cuda $ input
    CUDA idx -> unsafeThrowableIO $ do
      opts :: ForeignPtr ATen.TensorOptions <- cast1 ATen.tensor_options input
      opts' :: ForeignPtr ATen.TensorOptions <- cast2 ATen.tensorOptions_device_index_s opts idx
      cast4 ATen.tensor_to_obb input opts' nonBlocking copy
  where
    nonBlocking :: Bool
    nonBlocking = False
    copy :: Bool
    copy = False

-- | Compute-device annotations for which the device of a tensor can be
-- recovered at runtime, either as a gradually typed singleton
-- ('sGetDevice') or as a plain value ('getDeviceType').
class SGetDevice (device :: Device (DeviceType Nat)) where
  -- | Returns the gradually typed compute device of the input tensor.
  --
  -- >>> ones' device = sOnes $ TensorSpec (SGradient SWithGradient) (SLayout SDense) device (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
  -- >>> t <- ones' $ SDevice SCPU
  -- >>> sGetDevice t
  -- SDevice SCPU
  -- >>> t <- ones' $ SUncheckedDevice CPU
  -- >>> sGetDevice t
  -- SUncheckedDevice CPU
  sGetDevice ::
    forall gradient layout dataType shape.
    -- | input
    Tensor gradient layout device dataType shape ->
    -- | compute device of the input tensor
    SDevice device

  -- | Returns the untyped compute device of the input tensor.
  --
  -- >>> ones' device = sOnes $ TensorSpec (SGradient SWithGradient) (SLayout SDense) device (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
  -- >>> t <- ones' $ SDevice SCPU
  -- >>> getDeviceType t
  -- CPU
  -- >>> t <- ones' $ SUncheckedDevice CPU
  -- >>> getDeviceType t
  -- CPU
  getDeviceType ::
    forall gradient layout dataType shape.
    -- | input
    Tensor gradient layout device dataType shape ->
    -- | compute device of the input tensor
    DeviceType Int16
  getDeviceType tensor =
    -- Demote the device singleton to a value, then drop the
    -- checked/unchecked wrapper.
    forgetIsChecked (fromSing (sGetDevice tensor))

-- | For an unchecked device annotation, the device is read from the tensor
-- at runtime: CUDA (with the device index reported by ATen) when CUDA is
-- available and the tensor reports being on CUDA, and CPU otherwise.
instance SGetDevice 'UncheckedDevice where
  sGetDevice tensor
    | not isOnCUDA = SUncheckedDevice CPU
    | otherwise = SUncheckedDevice (CUDA (fromIntegral cudaIndex))
    where
      -- NOTE(review): 'unsafePerformIO' treats these ATen queries as pure
      -- reads of the tensor's metadata; this assumes the tensor's device
      -- does not change in place.
      isOnCUDA :: Bool
      isOnCUDA =
        unsafePerformIO (cast0 ATen.hasCUDA)
          && unsafePerformIO (cast1 ATen.tensor_is_cuda tensor)
      -- Only forced when 'isOnCUDA' holds.
      cudaIndex :: Int
      cudaIndex = unsafePerformIO (cast1 ATen.tensor_get_device tensor)

-- | For a tensor annotated as living on the CPU, this confirms at runtime
-- that it is not on CUDA. It calls 'error' if it is.
instance SGetDevice ('Device 'CPU) where
  sGetDevice tensor
    | isOnCUDA =
      error $ "The tensor should be on CPU but is on CUDA. " <> gitHubErrorMsg
    | otherwise = SDevice SCPU
    where
      -- NOTE(review): runtime metadata query wrapped in 'unsafePerformIO';
      -- assumes the tensor's device does not change in place.
      isOnCUDA :: Bool
      isOnCUDA =
        unsafePerformIO (cast0 ATen.hasCUDA)
          && unsafePerformIO (cast1 ATen.tensor_is_cuda tensor)

-- | For a tensor annotated as living on CUDA device @deviceIndex@, this
-- confirms at runtime that it is on CUDA and on that exact device index.
-- It calls 'error' otherwise.
instance KnownNat deviceIndex => SGetDevice ('Device ('CUDA deviceIndex)) where
  sGetDevice tensor
    | not isOnCUDA =
      error $ "The tensor should be on CUDA but is on CPU. " <> gitHubErrorMsg
    | actualIndex == fromIntegral expectedIndex = SDevice SCUDA
    | otherwise =
      error $
        "The tensor should be on CUDA device "
          <> show expectedIndex
          <> " but is on device "
          <> show actualIndex
          <> ". "
          <> gitHubErrorMsg
    where
      -- Device index from the type-level annotation.
      expectedIndex :: Integer
      expectedIndex = natVal (Proxy @deviceIndex)
      -- NOTE(review): runtime metadata queries wrapped in 'unsafePerformIO';
      -- assumes the tensor's device does not change in place.
      isOnCUDA :: Bool
      isOnCUDA =
        unsafePerformIO (cast0 ATen.hasCUDA)
          && unsafePerformIO (cast1 ATen.tensor_is_cuda tensor)
      -- Only forced once 'isOnCUDA' is known to hold.
      actualIndex :: Int
      actualIndex = unsafePerformIO (cast1 ATen.tensor_get_device tensor)

-- | Mismatch between the device a tensor was expected to be on and the
-- device it was found on. Thrown by 'sCheckedDevice'.
data DeviceError = DeviceError
  { -- | device the tensor was expected to be on
    deExpected :: DeviceType Int16,
    -- | device the tensor was actually found on
    deActual :: DeviceType Int16
  }
  deriving stock (Show, Typeable)

-- | Human-readable rendering naming both the expected and the actual device.
instance Exception DeviceError where
  displayException (DeviceError expected actual) =
    concat
      [ "The tensor is not in the memory of the device `",
        show expected,
        "` but `",
        show actual,
        "`."
      ]

-- | Checks whether or not the input tensor is in the memory of the device
-- given as first argument, and returns the same tensor (re-annotated via
-- 'coerce', no data is copied) with that device in its type, wrapped in a
-- 'MonadThrow' 'm'. On a mismatch a 'DeviceError' is thrown in 'm'.
--
-- For instance, if 'm' is 'Maybe', then the result will be wrapped in 'Just' if and only if the tensor is indeed on 'device'.
-- If it is not, then the result will be 'Nothing'.
--
-- Note that the 'sGetDevice' instances for statically known input devices
-- call 'error' (rather than throwing in 'm') when the tensor's runtime
-- device disagrees with its type annotation.
--
-- In the REPL, 'm' will default to 'IO':
-- >>> t <- sOnes $ TensorSpec (SGradient SWithGradient) (SLayout SDense) (SUncheckedDevice CPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
-- >>> t' <- sCheckedDevice (SDevice SCPU) t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
-- >>> t' <- sCheckedDevice (SDevice (SCUDA @0)) t
-- *** Exception: DeviceError {deExpected = CUDA 0, deActual = CPU}
sCheckedDevice ::
  forall device' m gradient layout device dataType shape.
  (SGetDevice device, MonadThrow m, Catch (device <+> device')) =>
  -- | device
  SDevice device' ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient layout device' dataType shape)
sCheckedDevice device' tensor
  | actualDevice == expectedDevice = pure (coerce tensor)
  | otherwise = throwM (DeviceError expectedDevice actualDevice)
  where
    -- device the tensor is found on at runtime
    actualDevice :: DeviceType Int16
    actualDevice = forgetIsChecked (fromSing (sGetDevice tensor))
    -- device requested by the caller
    expectedDevice :: DeviceType Int16
    expectedDevice = forgetIsChecked (fromSing device')

-- | Like 'sCheckedDevice', but the expected device @device'@ is taken from
-- a type argument (e.g. @checkedDevice \@('Device 'CPU) t@) through 'SingI'
-- instead of being passed as a singleton value.
checkedDevice ::
  forall device' m gradient layout device dataType shape.
  (SingI device', SGetDevice device, MonadThrow m, Catch (device <+> device')) =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient layout device' dataType shape)
checkedDevice tensor = sCheckedDevice (sing @device') tensor

-- | Returns the input tensor but with 'UncheckedDevice' as device type annotation.
-- Any static information about the tensor's device is thus erased.
-- However, the tensor's underlying data structure is not changed: this is a
-- 'coerce' of the type annotation only.
--
-- >>> t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
-- >>> :type uncheckedDevice t
-- uncheckedDevice t
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        'UncheckedDevice
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
uncheckedDevice ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | tensor without checked device
  Tensor gradient layout 'UncheckedDevice dataType shape
uncheckedDevice tensor = coerce tensor

-- | Returns a copy of the tensor converted to 'Bool'.
--
-- The result is annotated with @'Gradient 'WithoutGradient@ regardless of
-- the input's gradient annotation. Failures raised by the underlying ATen
-- call surface in 'm' via 'unsafeThrowableIO'.
bool ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | 'Bool' copy
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Bool) shape)
bool input = unsafeThrowableIO (cast2 ATen.tensor_toType_s input targetDType)
  where
    targetDType :: DType
    targetDType = Bool

-- | Returns a copy of the tensor converted to 'UInt8'.
--
-- The result is annotated with @'Gradient 'WithoutGradient@ regardless of
-- the input's gradient annotation. Failures raised by the underlying ATen
-- call surface in 'm' via 'unsafeThrowableIO'.
byte ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | 'UInt8' copy
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'UInt8) shape)
byte input = unsafeThrowableIO (cast2 ATen.tensor_toType_s input targetDType)
  where
    targetDType :: DType
    targetDType = UInt8

-- | Returns a copy of the tensor converted to 'Int8'.
--
-- The result is annotated with @'Gradient 'WithoutGradient@ regardless of
-- the input's gradient annotation. Failures raised by the underlying ATen
-- call surface in 'm' via 'unsafeThrowableIO'.
char ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | 'Int8' copy
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int8) shape)
char input = unsafeThrowableIO (cast2 ATen.tensor_toType_s input targetDType)
  where
    targetDType :: DType
    targetDType = Int8

-- | Returns a copy of the tensor converted to 'Int16'.
--
-- The result is annotated with @'Gradient 'WithoutGradient@ regardless of
-- the input's gradient annotation. Failures raised by the underlying ATen
-- call surface in 'm' via 'unsafeThrowableIO'.
short ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | 'Int16' copy
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int16) shape)
short input = unsafeThrowableIO (cast2 ATen.tensor_toType_s input targetDType)
  where
    targetDType :: DType
    targetDType = Int16

-- | Returns a copy of the tensor converted to 'Int32'.
--
-- The result is annotated with @'Gradient 'WithoutGradient@ regardless of
-- the input's gradient annotation. Failures raised by the underlying ATen
-- call surface in 'm' via 'unsafeThrowableIO'.
int ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | 'Int32' copy
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int32) shape)
int input = unsafeThrowableIO (cast2 ATen.tensor_toType_s input targetDType)
  where
    targetDType :: DType
    targetDType = Int32

-- | Returns a copy of the tensor converted to 'Int64'.
--
-- The result is annotated with @'Gradient 'WithoutGradient@ regardless of
-- the input's gradient annotation. Failures raised by the underlying ATen
-- call surface in 'm' via 'unsafeThrowableIO'.
long ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | 'Int64' copy
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int64) shape)
long input = unsafeThrowableIO (cast2 ATen.tensor_toType_s input targetDType)
  where
    targetDType :: DType
    targetDType = Int64

-- | Returns a copy of the tensor converted to the 16-bit floating point format 'Half'.
--
-- Unlike the integral conversions, the input's gradient annotation is kept.
-- Failures raised by the underlying ATen call surface in 'm' via
-- 'unsafeThrowableIO'.
half ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | 'Half' copy
  m (Tensor gradient layout device ('DataType 'Half) shape)
half input = unsafeThrowableIO (cast2 ATen.tensor_toType_s input targetDType)
  where
    targetDType :: DType
    targetDType = Half

-- | Returns a copy of the tensor converted to the 32-bit floating point format 'Float'.
--
-- Unlike the integral conversions, the input's gradient annotation is kept.
-- Failures raised by the underlying ATen call surface in 'm' via
-- 'unsafeThrowableIO'.
float ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | 'Float' copy
  m (Tensor gradient layout device ('DataType 'Float) shape)
float input = unsafeThrowableIO (cast2 ATen.tensor_toType_s input targetDType)
  where
    targetDType :: DType
    targetDType = Float

-- | Returns a copy of the tensor converted to the 64-bit floating point format 'Double'.
--
-- Unlike the integral conversions, the input's gradient annotation is kept.
-- Failures raised by the underlying ATen call surface in 'm' via
-- 'unsafeThrowableIO'.
double ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | 'Double' copy
  m (Tensor gradient layout device ('DataType 'Double) shape)
double tensor = unsafeThrowableIO $ cast2 ATen.tensor_toType_s tensor Double

-- | Set the data type of a tensor to the specified data type.
--
-- The tensor's options are copied with only the data type replaced, and the
-- tensor is converted with @non_blocking = False@ and @copy = False@ through
-- @tensor_to_obb@. Failures raised by ATen surface in 'm' via
-- 'unsafeThrowableIO'.
sSetDataType ::
  forall m gradient layout device dataType dataType' shape.
  MonadThrow m =>
  SDataType dataType ->
  Tensor gradient layout device dataType' shape ->
  m (Tensor gradient layout device dataType shape)
sSetDataType dataType input = unsafeThrowableIO $ do
  opts :: ForeignPtr ATen.TensorOptions <- cast1 ATen.tensor_options input
  opts' :: ForeignPtr ATen.TensorOptions <- cast2 ATen.tensorOptions_dtype_s opts dType
  cast4 ATen.tensor_to_obb input opts' nonBlocking copy
  where
    -- target data type demoted from the singleton
    dType :: DType
    dType = forgetIsChecked (fromSing dataType)
    nonBlocking :: Bool
    nonBlocking = False
    copy :: Bool
    copy = False

-- | Data-type annotations for which the data type of a tensor can be
-- recovered at runtime, either as a gradually typed singleton
-- ('sGetDataType') or as a plain value ('getDType').
class SGetDataType (dataType :: DataType DType) where
  -- | Returns the gradually typed compute data type of the input tensor.
  --
  -- >>> sOnes' dataType = sOnes $ TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) dataType (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
  -- >>> t <- sOnes' $ SDataType SFloat
  -- >>> sGetDataType t
  -- SDataType SFloat
  -- >>> t <- sOnes' $ SUncheckedDataType Float
  -- >>> sGetDataType t
  -- SUncheckedDataType Float
  sGetDataType ::
    forall gradient layout device shape.
    -- | input
    Tensor gradient layout device dataType shape ->
    -- | data type of the input tensor
    SDataType dataType

  -- | Returns the untyped compute data type of the input tensor.
  --
  -- >>> sOnes' dataType = sOnes $ TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) dataType (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
  -- >>> t <- sOnes' $ SDataType SFloat
  -- >>> getDType t
  -- Float
  -- >>> t <- sOnes' $ SUncheckedDataType Float
  -- >>> getDType t
  -- Float
  getDType ::
    forall gradient layout device shape.
    -- | input
    Tensor gradient layout device dataType shape ->
    -- | data type of the input tensor
    DType
  getDType tensor =
    -- Demote the data-type singleton, then drop the checked/unchecked wrapper.
    forgetIsChecked (fromSing (sGetDataType tensor))

-- | For an unchecked data-type annotation, the data type is read from the
-- tensor's ATen scalar type at runtime.
instance SGetDataType 'UncheckedDataType where
  sGetDataType tensor =
    -- NOTE(review): runtime metadata query wrapped in 'unsafePerformIO';
    -- assumes the tensor's scalar type does not change in place.
    let scalarType = unsafePerformIO (cast1 ATen.tensor_scalar_type tensor)
     in SUncheckedDataType scalarType

-- | For a tensor annotated with data type @dType@, this confirms at runtime
-- that the tensor's ATen scalar type equals @dType@. On a mismatch it calls
-- 'error' with a message naming both the expected and the actual data type.
instance SingI dType => SGetDataType ('DataType dType) where
  sGetDataType tensor
    | actual == expected = SDataType (sing @dType)
    | otherwise =
      error $
        "The tensor should have data type "
          <> show expected
          <> " but has data type "
          <> show actual
          <> ". "
          <> gitHubErrorMsg
    where
      -- Data type from the type-level annotation.
      expected :: DType
      expected = fromSing (sing @dType)
      -- Data type reported by ATen. It is queried once and shared between the
      -- check and the error message.
      -- NOTE(review): 'unsafePerformIO' assumes the scalar type does not
      -- change in place.
      actual :: DType
      actual = unsafePerformIO (cast1 ATen.tensor_scalar_type tensor)

-- | Mismatch between the data type a tensor was expected to have and the
-- data type it actually has.
data DataTypeError = DataTypeError
  { -- | data type the tensor was expected to have
    dtExpected :: DType,
    -- | data type the tensor actually has
    dtActual :: DType
  }
  deriving stock (Show, Typeable)

-- | Human-readable rendering naming both the expected and the actual data type.
instance Exception DataTypeError where
  displayException (DataTypeError expected actual) =
    concat
      [ "The tensor does not have the data type `",
        show expected,
        "` but `",
        show actual,
        "`."
      ]

-- | Checks whether or not the input tensor has the data type given by the
-- singleton @dataType'@ and returns a statically annotated copy of it wrapped
-- in a 'MonadThrow' 'm'. Only the type-level annotation changes; the tensor
-- is passed through 'coerce'.
--
-- For instance, if 'm' is 'Maybe', then the result will be wrapped in 'Just' if and only if the tensor has indeed the data type @dataType'@.
-- If it does not have it, then the result will be 'Nothing'.
-- More generally, a mismatch is reported by throwing a 'DataTypeError' via 'throwM'.
--
-- In the REPL, 'm' will default to 'IO':
-- >>> t <- sOnes $ TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SUncheckedDataType Float) (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil)
-- >>> t' <- sCheckedDataType (SDataType SFloat) t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
-- >>> t' <- sCheckedDataType (SDataType SDouble) t
-- *** Exception: DataTypeError {dtExpected = Double, dtActual = Float}
sCheckedDataType ::
  forall dataType' m gradient layout device dataType shape.
  (SGetDataType dataType, MonadThrow m, Catch (dataType <+> dataType')) =>
  -- | data type
  SDataType dataType' ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient layout device dataType' shape)
sCheckedDataType dataType' tensor
  | actualDataType == expectedDataType = pure (coerce tensor)
  | otherwise = throwM (DataTypeError expectedDataType actualDataType)
  where
    -- both sides are demoted to plain 'DType's so checked and unchecked
    -- annotations compare by their runtime value
    actualDataType = forgetIsChecked (fromSing (sGetDataType tensor))
    expectedDataType = forgetIsChecked (fromSing dataType')

-- | Variant of 'sCheckedDataType' whose target data type @dataType'@ is
-- taken from its 'SingI' instance, e.g. via a type application
-- (@checkedDataType \@('DataType 'Float)@), instead of an explicit singleton.
checkedDataType ::
  forall dataType' m gradient layout device dataType shape.
  (SingI dataType', SGetDataType dataType, MonadThrow m, Catch (dataType <+> dataType')) =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient layout device dataType' shape)
checkedDataType tensor = sCheckedDataType (sing @dataType') tensor

-- | Forgets the statically known data type of a tensor by re-annotating it
-- with 'UncheckedDataType'.
-- Only the type changes: the value is passed through 'coerce', so the
-- tensor's underlying data structure is left untouched.
--
-- >>> t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
-- >>> :type uncheckedDataType t
-- uncheckedDataType t
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        'UncheckedDataType
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
uncheckedDataType ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | tensor without checked data type
  Tensor gradient layout device 'UncheckedDataType shape
uncheckedDataType tensor = coerce tensor

-- | Shape annotations whose singleton can be recovered from a tensor at runtime.
class SGetShape (shape :: Shape [Dim (Name Symbol) (Size Nat)]) where
  -- | Returns the gradually typed shape of the input tensor.
  --
  -- >>> sOnes' = sOnes . TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat)
  -- >>> t <- sOnes' . SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil
  -- >>> sGetShape t
  -- SShape (SCons (SDim {sDimName = SName, sDimSize = SSize}) (SCons (SDim {sDimName = SName, sDimSize = SSize}) SNil))
  -- >>> t <- sOnes' . SUncheckedShape $ [Dim "batch" 32, Dim "feature" 8]
  -- >>> sGetShape t
  -- SUncheckedShape [Dim {dimName = "batch", dimSize = 32},Dim {dimName = "feature", dimSize = 8}]
  -- >>> t <- sOnes' . SShape $ SUncheckedName "batch" :&: SUncheckedSize 32 :|: SUncheckedName "feature" :&: SSize @32 :|: SNil
  -- >>> sGetShape t
  -- SShape (SCons (SDim {sDimName = SUncheckedName "batch", sDimSize = SUncheckedSize 32}) (SCons (SDim {sDimName = SUncheckedName "feature", sDimSize = SSize}) SNil))
  -- >>> t <- sOnes' . SShape $ SName @"batch" :&: SUncheckedSize 32 :|: SName @"feature" :&: SUncheckedSize 8 :|: SNil
  -- >>> sGetShape t
  -- SShape (SCons (SDim {sDimName = SName, sDimSize = SUncheckedSize 32}) (SCons (SDim {sDimName = SName, sDimSize = SUncheckedSize 8}) SNil))
  sGetShape ::
    forall gradient layout device dataType.
    Tensor gradient layout device dataType shape ->
    SShape shape

  -- | Returns the untyped shape of the input tensor: one 'Dim' per dimension,
  -- with every (checked or unchecked) name and size demoted to a plain value.
  -- The default implementation demotes the result of 'sGetShape'.
  --
  -- >>> sOnes' = sOnes . TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat)
  -- >>> t <- sOnes' . SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil
  -- >>> getDims t
  -- [Dim {dimName = "batch", dimSize = 32},Dim {dimName = "feature", dimSize = 8}]
  -- >>> t <- sOnes' . SUncheckedShape $ [Dim "batch" 32, Dim "feature" 8]
  -- >>> getDims t
  -- [Dim {dimName = "batch", dimSize = 32},Dim {dimName = "feature", dimSize = 8}]
  -- >>> t <- sOnes' . SShape $ SUncheckedName "batch" :&: SUncheckedSize 32 :|: SUncheckedName "feature" :&: SSize @32 :|: SNil
  -- >>> getDims t
  -- [Dim {dimName = "batch", dimSize = 32},Dim {dimName = "feature", dimSize = 32}]
  -- >>> t <- sOnes' . SShape $ SName @"batch" :&: SUncheckedSize 32 :|: SName @"feature" :&: SUncheckedSize 8 :|: SNil
  -- >>> getDims t
  -- [Dim {dimName = "batch", dimSize = 32},Dim {dimName = "feature", dimSize = 8}]
  getDims ::
    forall gradient layout device dataType.
    Tensor gradient layout device dataType shape ->
    [Dim String Integer]
  getDims =
    map (bimap forgetIsChecked forgetIsChecked)
      . forgetIsChecked
      . fromSing
      . sGetShape

-- | Builds the shape entirely from runtime information: the sizes always come
-- from ATen, the names only if the tensor has any, otherwise each dimension
-- is labelled @"*"@.
instance SGetShape 'UncheckedShape where
  sGetShape tensor = SUncheckedShape . unsafePerformIO $ do
    sizes <- cast1 ATen.tensor_sizes tensor
    hasNames <- cast1 ATen.tensor_has_names tensor
    names <-
      if hasNames
        then cast1 ATen.tensor_names tensor
        else pure ("*" <$ sizes)
    pure (zipWith Dim names sizes)

-- | Reads the runtime names and sizes of the tensor and hands them to
-- 'sGetDims', which checks them against the statically known 'dims'.
instance SGetDims dims => SGetShape ('Shape dims) where
  sGetShape tensor = SShape (sGetDims names sizes)
    where
      -- ATen is only asked for the sizes when the tensor has rank > 0;
      -- a rank-0 tensor yields the empty size list.
      sizes :: [Integer]
      sizes = unsafePerformIO $ do
        rank :: Int <- cast1 ATen.tensor_dim tensor
        if rank > 0
          then cast1 ATen.tensor_sizes tensor
          else pure []
      -- Runtime names, or "*" for every dimension if the tensor is unnamed.
      names :: [String]
      names = unsafePerformIO $ do
        hasNames <- cast1 ATen.tensor_has_names tensor
        if hasNames
          then cast1 ATen.tensor_names tensor
          else pure ("*" <$ sizes)

-- | Rebuilds the singleton for a type-level list of dimensions from runtime
-- dimension names and sizes. The instances in this module fail with 'error'
-- when the runtime data does not match the static annotation.
class SGetDims (dims :: [Dim (Name Symbol) (Size Nat)]) where
  sGetDims ::
    -- | runtime dimension names
    [String] ->
    -- | runtime dimension sizes
    [Integer] ->
    -- | dimension singletons
    SList dims

-- | Fails because the number of dimensions known at compile time differs
-- from the number of dimensions found at runtime.
dimsError :: forall a. a
dimsError =
  error $
    "The numbers of compile- and runtime dimensions are not the same. "
      ++ gitHubErrorMsg

-- | Fails because a dimension name known at compile time differs from the
-- one found at runtime. Both names are quoted in the message.
dimNameError :: forall a. String -> String -> a
dimNameError name name' =
  error
    ( concat
        [ "The compile- and runtime dimension names are not the same, '",
          name,
          "' != '",
          name',
          "'. ",
          gitHubErrorMsg
        ]
    )

-- | Fails because a dimension size known at compile time differs from the
-- one found at runtime. Both sizes are rendered with 'show'.
dimSizeError :: forall a b. Show a => a -> a -> b
dimSizeError size size' =
  error
    ( concat
        [ "The compile- and runtime dimension sizes are not the same, '",
          show size,
          "' != '",
          show size',
          "'. ",
          gitHubErrorMsg
        ]
    )

-- | Fails because both the name and the size of a dimension known at
-- compile time differ from the ones found at runtime.
dimNameSizeError :: forall a b. Show a => String -> String -> a -> a -> b
dimNameSizeError name name' size size' =
  error
    ( concat
        [ "The compile- and runtime dimension names and sizes are not the same, '",
          name,
          "' != '",
          name',
          "' and '",
          show size,
          "' != '",
          show size',
          "'. ",
          gitHubErrorMsg
        ]
    )

-- | The empty dimension list only accepts empty runtime name and size lists.
instance SGetDims '[] where
  sGetDims names sizes
    | null names, null sizes = SNil
    | otherwise = dimsError

-- | Consumes one runtime name and size for the head dimension and recurses
-- on the rest; running out of either list is a dimension-count mismatch.
instance (SGetDim dim, SGetDims dims) => SGetDims (dim : dims) where
  sGetDims names sizes = case (names, sizes) of
    (name : names', size : sizes') -> sGetDim name size :|: sGetDims names' sizes'
    _ -> dimsError

-- | Rebuilds the singleton of a single dimension from its runtime name and
-- size, checking any statically known parts against them.
class SGetDim (dim :: Dim (Name Symbol) (Size Nat)) where
  sGetDim ::
    -- | runtime dimension name
    String ->
    -- | runtime dimension size
    Integer ->
    -- | dimension singleton
    SDim dim

-- | Nothing is known statically, so the runtime name and size are taken as-is.
instance SGetDim ('Dim 'UncheckedName 'UncheckedSize) where
  sGetDim name = SDim (SUncheckedName name) . SUncheckedSize

-- | The name is known statically and must equal the runtime name; the size
-- is taken from runtime.
instance KnownSymbol name => SGetDim ('Dim ('Name name) 'UncheckedSize) where
  sGetDim runtimeName size
    | runtimeName == staticName = SDim (SName @name) (SUncheckedSize size)
    | otherwise = dimNameError runtimeName staticName
    where
      staticName = symbolVal (Proxy @name)

-- | The size is known statically and must equal the runtime size; the name
-- is taken from runtime.
instance KnownNat size => SGetDim ('Dim 'UncheckedName ('Size size)) where
  sGetDim name runtimeSize
    | runtimeSize == staticSize = SDim (SUncheckedName name) (SSize @size)
    | otherwise = dimSizeError runtimeSize staticSize
    where
      staticSize = natVal (Proxy @size)

-- | Both name and size are known statically; each must equal its runtime
-- counterpart, and the error reports exactly which of them disagree.
instance (KnownSymbol name, KnownNat size) => SGetDim ('Dim ('Name name) ('Size size)) where
  sGetDim runtimeName runtimeSize =
    case (runtimeName == staticName, runtimeSize == staticSize) of
      (True, True) -> SDim (SName @name) (SSize @size)
      (False, True) -> dimNameError runtimeName staticName
      (True, False) -> dimSizeError runtimeSize staticSize
      (False, False) -> dimNameSizeError runtimeName staticName runtimeSize staticSize
    where
      staticName = symbolVal (Proxy @name)
      staticSize = natVal (Proxy @size)

-- | Raised by 'sCheckedShape' (and thus 'checkedShape') when the runtime
-- dimensions of a tensor differ from the requested ones.
data ShapeError = ShapeError
  { -- | dimensions that were requested
    seExpected :: [Dim String Integer],
    -- | dimensions the tensor actually has
    seActual :: [Dim String Integer]
  }
  deriving stock (Show)

-- | Human-readable rendering of a shape mismatch.
instance Exception ShapeError where
  displayException :: ShapeError -> String
  displayException (ShapeError expected actual) =
    concat
      [ "The tensor does not have the shape `",
        show expected,
        "` but `",
        show actual,
        "`."
      ]

-- | Checks whether or not the input tensor has the shape given by the
-- singleton @shape'@ and returns a statically annotated copy of it wrapped in
-- a 'MonadThrow' 'm'. Shapes are compared after demoting every name and size
-- to a plain value. A mismatch is reported by throwing a 'ShapeError' via 'throwM'.
--
-- For instance, if 'm' is 'Maybe', then the result will be wrapped in 'Just' if and only if the tensor has indeed the shape @shape'@.
-- If it does not, then the result will be 'Nothing'.
--
-- In the REPL, 'm' will default to 'IO':
-- >>> t <- sOnes $ TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SUncheckedShape [Dim "batch" 32, Dim "feature" 8])
-- >>> t' <- sCheckedShape (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SSize @8 :|: SNil) t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim ('Name "feature") ('Size 8)])
-- >>> t' <- sCheckedShape (SShape $ SUncheckedName "batch" :&: SSize @32 :|: SName @"feature" :&: SUncheckedSize 8 :|: SNil) t
-- >>> :type t'
-- t'
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim 'UncheckedName ('Size 32),
--              'Dim ('Name "feature") 'UncheckedSize])
-- >>> t' <- sCheckedShape (SShape $ SName @"batch" :&: SSize @32 :|: SName @"feature" :&: SUncheckedSize 32 :|: SNil) t
-- *** Exception: ShapeError {seExpected = [Dim {dimName = "batch", dimSize = 32},Dim {dimName = "feature", dimSize = 32}], seActual = [Dim {dimName = "batch", dimSize = 32},Dim {dimName = "feature", dimSize = 8}]}
sCheckedShape ::
  forall shape' m gradient layout device dataType shape.
  (SGetShape shape, MonadThrow m, Catch (shape <+> shape')) =>
  -- | shape
  SShape shape' ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient layout device dataType shape')
sCheckedShape shape' tensor
  | actualShape == expectedShape = pure (coerce tensor)
  | otherwise = throwM (ShapeError expectedShape actualShape)
  where
    -- demote a shape singleton to a list of plain (name, size) dimensions
    toDims :: Sing a -> [Dim String Integer]
    toDims = map (bimap forgetIsChecked forgetIsChecked) . forgetIsChecked . fromSing
    actualShape = toDims (sGetShape tensor)
    expectedShape = toDims shape'

-- | Variant of 'sCheckedShape' whose target shape @shape'@ is taken from its
-- 'SingI' instance, e.g. via a type application, instead of an explicit
-- singleton.
checkedShape ::
  forall shape' m gradient layout device dataType shape.
  (SingI shape', SGetShape shape, MonadThrow m, Catch (shape <+> shape')) =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | annotated output tensor wrapped in 'm'
  m (Tensor gradient layout device dataType shape')
checkedShape tensor = sCheckedShape (sing @shape') tensor

-- | Returns the input tensor with the dimension chosen by 'selectDim' replaced
-- by @'Dim 'UncheckedName 'UncheckedSize@ in its type annotation.
-- The static information about that dimension is thus erased.
-- Only the type changes: the value is passed through 'coerce', so the
-- tensor's underlying data structure is left untouched.
--
-- >>> t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
-- >>> :type uncheckedDim @('SelectDim ('ByName "batch")) t
-- uncheckedDim @('SelectDim ('ByName "batch")) t
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim 'UncheckedName 'UncheckedSize,
--              'Dim ('Name "feature") ('Size 8)])
-- >>> t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
-- >>> :type uncheckedDim @('SelectDim ('ByIndex 1)) t
-- uncheckedDim @('SelectDim ('ByIndex 1)) t
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "batch") ('Size 32),
--              'Dim 'UncheckedName 'UncheckedSize])
uncheckedDim ::
  forall selectDim gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | tensor with the selected dimensions unchecked
  Tensor gradient layout device dataType (ReplaceDimF selectDim shape ('Dim 'UncheckedName 'UncheckedSize))
uncheckedDim tensor = coerce tensor

-- | Forgets the statically known shape of a tensor by re-annotating it with
-- 'UncheckedShape'.
-- Only the type changes: the value is passed through 'coerce', so the
-- tensor's underlying data structure is left untouched.
--
-- >>> t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
-- >>> :type uncheckedShape t
-- uncheckedShape t
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        'UncheckedShape
uncheckedShape ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | tensor without checked shape
  Tensor gradient layout device dataType 'UncheckedShape
uncheckedShape tensor = coerce tensor

-- | Sentence appended to the internal-error messages in this module, asking
-- the user to report the problem.
gitHubErrorMsg :: String
gitHubErrorMsg =
  "Please open a ticket on GitHub."

-- | Asks ATen (@tensor_is_contiguous@) whether the tensor is contiguous.
isContiguous ::
  Tensor gradient layout device dataType shape ->
  Bool
isContiguous = unsafePerformIO . cast1 ATen.tensor_is_contiguous

-- | Returns the result of ATen's @tensor_contiguous@ for the input tensor,
-- with the same type-level annotations.
contiguous ::
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType shape
contiguous tensor = unsafePerformIO (cast1 ATen.tensor_contiguous tensor)

-- | Runs an action on the raw data pointer (from @tensor_data_ptr@) of a
-- contiguous version of the tensor: the tensor itself if 'isContiguous'
-- holds, otherwise the result of 'contiguous'.
-- The pointer is obtained inside 'withForeignPtr', so it should not be used
-- after the action returns.
withTensor :: Tensor gradient layout device dataType shape -> (Ptr () -> IO a) -> IO a
withTensor t fn =
  cast contiguousTensor $ \ct ->
    withForeignPtr ct $ \ptr ->
      Unmanaged.tensor_data_ptr ptr >>= fn
  where
    contiguousTensor
      | isContiguous t = t
      | otherwise = contiguous t

-- | Low-level conversion between Haskell values and raw tensor memory,
-- addressed by element offsets.
class TensorLikeRaw a where
  -- | Guesses the outer dimension of a value.
  --
  -- >>> guessDim @[[Int]] $ pure [[1, 2], [3, 4], [5, 6]]
  -- Just 3
  guessDim
    :: Maybe a
    -- ^ the value; 'Nothing' if the data type wrapping @a@ is empty
    -> Maybe Int
    -- ^ the outer dimension; 'Nothing' if @a@ is a scalar

  -- | Guesses the inner dimensions of a value.
  --
  -- >>> guessInnerDims @[[Int]] $ pure [[1, 2], [3, 4], [5, 6]]
  -- [2]
  guessInnerDims
    :: MonadThrow m
    => Maybe a
    -- ^ the value; 'Nothing' if the data type wrapping @a@ is empty
    -> m [Int]
    -- ^ the inner dimensions

  -- | Reads a value from a tensor.
  tensorPeekElemOff
    :: Ptr ()
    -- ^ pointer to the tensor data
    -> Int
    -- ^ element offset
    -> [Int]
    -- ^ tensor dimensions
    -> IO a
    -- ^ the value read

  -- | Writes a value to a tensor.
  tensorPokeElemOff
    :: Ptr ()
    -- ^ pointer to the tensor data
    -> Int
    -- ^ element offset
    -> [Int]
    -- ^ tensor dimensions
    -> a
    -- ^ the value to write
    -> IO ()

-- | Guesses dims: concatenates 'guessDim' with 'guessInnerDims'.
--
-- When 'guessDim' returns 'Nothing' (a scalar), no outer dimension is
-- prepended. Failures from 'guessInnerDims' are propagated through @m@.
--
-- >>> guessDims @[[Int]] $ pure [[1, 2], [3, 4], [5, 6]]
-- [3,2]
guessDims :: forall a m. (TensorLikeRaw a, MonadThrow m) => Maybe a -> m [Int]
guessDims x = (outerDim <>) <$> guessInnerDims x
  where
    outerDim :: [Int]
    outerDim = maybeToList $ guessDim x

-- | Fails with a message comparing the dims guessed for @x@ with @dims'@.
--
-- @x@ is only used to compute the expected shape via 'guessDims'; callers
-- in this module pass either 'empty' or @'pure' value@.
--
-- NOTE(review): despite the 'MonadThrow' constraint, the failure is raised
-- with 'error', so outside 'IO' it is an imprecise exception rather than a
-- failure reported through @m@.
unexpectedDimsError :: forall a m b. (TensorLikeRaw a, MonadThrow m) => [Int] -> Maybe a -> m b
unexpectedDimsError dims' x = do
  expected <- guessDims x
  error $ "Expected shape to be " <> show expected <> " got: " <> show dims'

-- | Values of type @a@ that can be converted to and from tensors. The
-- functional dependencies make @a@ determine both the element data type
-- @dType@ and the dimensions @dims@.
class
  TensorLike a (dType :: DType) (dims :: [Dim (Name Symbol) (Size Nat)])
    | a -> dims,
      a -> dType
  where
  -- | Creates a tensor from a 'TensorLike' value.
  --
  -- >>> t <- sToTensor (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) ([(1, 2), (3, 4), (5, 6)] :: [(Int, Int)])
  -- >>> t
  -- Tensor Int64 [3,2] [[ 1,  2],
  --                     [ 3,  4],
  --                     [ 5,  6]]
  -- >>> :type t
  -- t :: Tensor
  --        ('Gradient 'WithoutGradient)
  --        ('Layout 'Dense)
  --        ('Device 'CPU)
  --        ('DataType 'Int64)
  --        ('Shape
  --           '[ 'Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") ('Size 2)])
  sToTensor
    :: forall gradient layout device m
     . MonadThrow m
    => SGradient gradient
    -- ^ gradient singleton
    -> SLayout layout
    -- ^ layout singleton
    -> SDevice device
    -- ^ device singleton
    -> a
    -- ^ value to convert
    -> m (Tensor gradient layout device ('DataType dType) ('Shape dims))
    -- ^ the resulting tensor

  -- | Creates a 'TensorLike' from a tensor.
  fromTensor
    :: forall gradient layout device
     . Tensor gradient layout device ('DataType dType) ('Shape dims)
    -- ^ tensor to convert
    -> a
    -- ^ the resulting value

-- | Non-singleton version of 'sToTensor'.
--
-- The gradient, layout and device singletons come from 'SingI' instances,
-- so they are typically fixed with type applications (in the order
-- gradient, layout, device).
toTensor ::
  forall gradient layout device a dType dims m.
  ( TensorLike a dType dims,
    SingI gradient,
    SingI layout,
    SingI device,
    MonadThrow m
  ) =>
  a ->
  m (Tensor gradient layout device ('DataType dType) ('Shape dims))
toTensor = sToTensor (sing @gradient) (sing @layout) (sing @device)

-- | Generic 'sToTensor' implementation for types with a 'TensorLikeRaw'
-- instance.
--
-- The function works in four steps:
--
-- 1. It guesses the dimensions from the value.
-- 2. It allocates a CPU tensor of that shape with @ATen.empty_lo@.
-- 3. It writes the value into the tensor's buffer with 'tensorPokeElemOff',
--    starting at offset 0.
-- 4. It moves the result to @device@.
--
-- Only step 1 (dimension guessing) reports failures through @m@.
--
-- NOTE(review): the allocation and copy run under 'unsafePerformIO' and are
-- returned with 'pure'. Failures from 'tensorPokeElemOff' or 'sSetDevice'
-- are therefore thrown when the resulting tensor is forced, not in @m@.
sToTensorRaw ::
  forall gradient layout device a dType dims m.
  (TensorLike a dType dims, TensorLikeRaw a, SingI dType, MonadThrow m) =>
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  a ->
  m (Tensor gradient layout device ('DataType dType) ('Shape dims))
sToTensorRaw gradient' layout device x = do
  dims' <- guessDims $ pure x

  pure $
    unsafePerformIO $ do
      t <- UnsafeTensor <$> cast2 ATen.empty_lo dims' opts
      withTensor t $ \ptr ->
        tensorPokeElemOff ptr 0 dims' x
      sSetDevice device t
  where
    -- Always allocate on the CPU: the value is written through a host pointer.
    opts = tensorOptions gradient' layout (SDevice SCPU) (sing @('DataType dType))

-- | Generic 'fromTensor' implementation for types with a 'TensorLikeRaw'
-- instance.
--
-- The tensor is first moved to the CPU. The value is then read from its
-- data buffer with 'tensorPeekElemOff', starting at offset 0, using the
-- dimension sizes reported by 'getDims'.
fromTensorRaw ::
  forall gradient layout device a dType dims.
  (TensorLike a dType dims, TensorLikeRaw a, SGetDims dims) =>
  Tensor gradient layout device ('DataType dType) ('Shape dims) ->
  a
fromTensorRaw t = unsafePerformIO $ do
  cpuTensor <- sSetDevice (SDevice SCPU) t
  withTensor cpuTensor $ \ptr -> tensorPeekElemOff ptr 0 dims'
  where
    dims' = fromInteger . dimSize <$> getDims t

-- | A 'Bool' converts to a scalar (zero-dimensional) tensor of data type 'Bool.
instance TensorLike Bool 'Bool '[] where
  sToTensor = sToTensorRaw
  fromTensor = fromTensorRaw

-- | A 'Bool' is a scalar stored as one byte ('Word8'): 'True' is written with
-- 'fromBool' and a stored byte reads back as 'True' only when it equals 1.
instance TensorLikeRaw Bool where
  guessDim = const empty

  guessInnerDims = const $ pure mempty

  tensorPeekElemOff ptr offset [] = peekElemOff @Word8 (castPtr ptr) offset <&> (== 1)
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @Bool dims' empty

  tensorPokeElemOff ptr offset [] x = pokeElemOff @Word8 (castPtr ptr) offset (fromBool x)
  tensorPokeElemOff _ _ dims' x = unexpectedDimsError dims' $ pure x

-- | An 'Int' converts to a scalar (zero-dimensional) tensor of data type 'Int64.
instance TensorLike Int 'Int64 '[] where
  sToTensor = sToTensorRaw
  fromTensor = fromTensorRaw

-- | An 'Int' is a scalar read and written with its 'Storable' instance.
--
-- NOTE(review): the matching tensor data type is 'Int64, so this assumes a
-- 64-bit 'Int' — confirm on 32-bit targets.
instance TensorLikeRaw Int where
  guessDim = const empty

  guessInnerDims = const $ pure mempty

  tensorPeekElemOff ptr offset [] = peekElemOff (castPtr ptr) offset
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @Int dims' empty

  tensorPokeElemOff ptr offset [] x = pokeElemOff (castPtr ptr) offset x
  tensorPokeElemOff _ _ dims' x = unexpectedDimsError dims' $ pure x

-- | A 'Float' converts to a scalar (zero-dimensional) tensor of data type 'Float.
instance TensorLike Float 'Float '[] where
  sToTensor = sToTensorRaw
  fromTensor = fromTensorRaw

-- | A 'Float' is a scalar read and written with its 'Storable' instance.
instance TensorLikeRaw Float where
  guessDim = const empty

  guessInnerDims = const $ pure mempty

  tensorPeekElemOff ptr offset [] = peekElemOff (castPtr ptr) offset
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @Float dims' empty

  tensorPokeElemOff ptr offset [] x = pokeElemOff (castPtr ptr) offset x
  tensorPokeElemOff _ _ dims' x = unexpectedDimsError dims' $ pure x

-- | A 'Double' converts to a scalar (zero-dimensional) tensor of data type 'Double.
instance TensorLike Double 'Double '[] where
  sToTensor = sToTensorRaw
  fromTensor = fromTensorRaw

-- | A 'Double' is a scalar read and written with its 'Storable' instance.
instance TensorLikeRaw Double where
  guessDim = const empty

  guessInnerDims = const $ pure mempty

  tensorPeekElemOff ptr offset [] = peekElemOff (castPtr ptr) offset
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @Double dims' empty

  tensorPokeElemOff ptr offset [] x = pokeElemOff (castPtr ptr) offset x
  tensorPokeElemOff _ _ dims' x = unexpectedDimsError dims' $ pure x

-- | Error raised when elements on the same dimension of a value being
-- converted to a tensor do not all have the same shape.
data DimMismatchError = DimMismatchError
  { -- | Shape of the first element.
    dmeFirst :: [Int],
    -- | Shape of an element that disagrees with the first one.
    dmeOther :: [Int]
  }
  deriving (Show, Eq)

-- | Human-readable description naming both offending shapes.
instance Exception DimMismatchError where
  displayException DimMismatchError {..} =
    "When converting to a tensor, all elements on the same dimension must have the same shape, "
      <> "but the first element has shape "
      <> show dmeFirst
      <> " while another element has shape "
      <> show dmeOther
      <> "."

-- | Succeeds when both shapes are equal. Otherwise it throws a
-- 'DimMismatchError' (first shape, other shape) in @m@.
checkDims :: MonadThrow m => [Int] -> [Int] -> m ()
checkDims firstDims otherDims =
  when (firstDims /= otherDims) $
    throwM $ DimMismatchError firstDims otherDims

-- | A pair converts to a tensor whose new leading dimension has size 2,
-- followed by the (unified) dimensions of its components.
instance
  ( TensorLike a dType dims,
    TensorLike b dType dims',
    TensorLikeRaw a,
    TensorLikeRaw b,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape (dims <+> dims')) ('Dim ('Name "*") ('Size 2))
  ) =>
  TensorLike (a, b) dType dimsOut
  where
  sToTensor = sToTensorRaw
  fromTensor = fromTensorRaw

-- | A pair has outer dimension 2. Its components must have identical shapes
-- and are stored one after the other: the second starts @product innerDims@
-- elements after the first.
instance (TensorLikeRaw a, TensorLikeRaw b) => TensorLikeRaw (a, b) where
  guessDim = const $ pure 2

  guessInnerDims (unzip -> (x, y)) = do
    xDims <- guessDims x
    yDims <- guessDims y
    checkDims xDims yDims
    pure xDims

  tensorPeekElemOff ptr offset (2 : innerDims) =
    (,)
      <$> tensorPeekElemOff ptr offset innerDims
      <*> tensorPeekElemOff ptr (offset + width) innerDims
    where
      -- number of elements occupied by one component
      width = product innerDims
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @(a, b) dims' empty

  tensorPokeElemOff ptr offset (2 : innerDims) (x, y) = do
    tensorPokeElemOff ptr offset innerDims x
    tensorPokeElemOff ptr (offset + width) innerDims y
    where
      -- number of elements occupied by one component
      width = product innerDims
  tensorPokeElemOff _ _ dims' x = unexpectedDimsError dims' $ pure x

-- | A triple converts to a tensor whose new leading dimension has size 3,
-- followed by the (unified) dimensions of its components.
instance
  ( TensorLike a dType dims,
    TensorLike b dType dims',
    TensorLike c dType dims',
    TensorLikeRaw a,
    TensorLikeRaw b,
    TensorLikeRaw c,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape (dims <+> dims')) ('Dim ('Name "*") ('Size 3))
  ) =>
  TensorLike (a, b, c) dType dimsOut
  where
  sToTensor = sToTensorRaw
  fromTensor = fromTensorRaw

-- | Splits a functor of triples into a triple of functors: the triple
-- counterpart of 'Data.List.NonEmpty.unzip'.
unzip3 :: Functor f => f (a, b, c) -> (f a, f b, f c)
unzip3 xyz = (fst3 <$> xyz, snd3 <$> xyz, thd3 <$> xyz)
  where
    fst3 (x, _, _) = x
    snd3 (_, y, _) = y
    thd3 (_, _, z) = z

instance (TensorLikeRaw a, TensorLikeRaw b, TensorLikeRaw c) => TensorLikeRaw (a, b, c) where
  -- | A triple has outer dimension 3. This matches the @3 : innerDims@
  -- pattern in 'tensorPeekElemOff' and the @'Size 3@ of its 'TensorLike'
  -- instance. (It previously returned 2, which made every conversion of a
  -- triple fail the dimension check.)
  guessDim = const $ pure 3

  guessInnerDims :: forall (m :: * -> *). MonadThrow m => Maybe (a, b, c) -> m [Int]
guessInnerDims (forall (f :: * -> *) a b c.
Functor f =>
f (a, b, c) -> (f a, f b, f c)
unzip3 -> (Maybe a
x, Maybe b
y, Maybe c
z)) = do
    [Int]
xDims <- forall a (m :: * -> *).
(TensorLikeRaw a, MonadThrow m) =>
Maybe a -> m [Int]
guessDims Maybe a
x
    [Int]
yDims <- forall a (m :: * -> *).
(TensorLikeRaw a, MonadThrow m) =>
Maybe a -> m [Int]
guessDims Maybe b
y
    [Int]
zDims <- forall a (m :: * -> *).
(TensorLikeRaw a, MonadThrow m) =>
Maybe a -> m [Int]
guessDims Maybe c
z
    forall (t :: * -> *) (f :: * -> *) a b.
(Foldable t, Applicative f) =>
(a -> f b) -> t a -> f ()
traverse_ (forall (m :: * -> *). MonadThrow m => [Int] -> [Int] -> m ()
checkDims [Int]
xDims) [[Int]
yDims, [Int]
zDims]
    forall (f :: * -> *) a. Applicative f => a -> f a
pure [Int]
xDims

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO (a, b, c)
tensorPeekElemOff Ptr ()
ptr Int
offset (Int
3 : [Int]
innerDims) =
    (,,)
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> IO a
tensorPeekElemOff Ptr ()
ptr Int
offset [Int]
innerDims
      forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> IO a
tensorPeekElemOff Ptr ()
ptr (Int
offset forall a. Num a => a -> a -> a
+ Int
width) [Int]
innerDims
      forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> IO a
tensorPeekElemOff Ptr ()
ptr (Int
offset forall a. Num a => a -> a -> a
+ Int
2 forall a. Num a => a -> a -> a
* Int
width) [Int]
innerDims
    where
      width :: Int
width = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Int]
innerDims
  tensorPeekElemOff Ptr ()
_ Int
_ [Int]
dims' = forall a (m :: * -> *) b.
(TensorLikeRaw a, MonadThrow m) =>
[Int] -> Maybe a -> m b
unexpectedDimsError @(a, b) [Int]
dims' forall (f :: * -> *) a. Alternative f => f a
empty

  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> (a, b, c) -> IO ()
tensorPokeElemOff Ptr ()
ptr Int
offset (Int
3 : [Int]
innerDims) (a
x, b
y, c
z) = do
    forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> a -> IO ()
tensorPokeElemOff Ptr ()
ptr Int
offset [Int]
innerDims a
x
    forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> a -> IO ()
tensorPokeElemOff Ptr ()
ptr (Int
offset forall a. Num a => a -> a -> a
+ Int
width) [Int]
innerDims b
y
    forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> a -> IO ()
tensorPokeElemOff Ptr ()
ptr (Int
offset forall a. Num a => a -> a -> a
+ Int
2 forall a. Num a => a -> a -> a
* Int
width) [Int]
innerDims c
z
    where
      width :: Int
width = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Int]
innerDims
  tensorPokeElemOff Ptr ()
_ Int
_ [Int]
dims' (a, b, c)
x = forall a (m :: * -> *) b.
(TensorLikeRaw a, MonadThrow m) =>
[Int] -> Maybe a -> m b
unexpectedDimsError [Int]
dims' forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure (a, b, c)
x

-- | Lists convert to tensors by inserting a new leading dimension named
-- @"*"@ at index 0 in front of the element shape. Its size is
-- 'UncheckedSize', i.e. not fixed at the type level.
instance
  ( TensorLike a dType dims,
    TensorLikeRaw a,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape dims) ('Dim ('Name "*") 'UncheckedSize)
  ) =>
  TensorLike [a] dType dimsOut
  where
  -- Both directions delegate to the generic raw conversions, which are
  -- driven by the 'TensorLikeRaw' instance for lists.
  sToTensor gradient' layout device xs = sToTensorRaw gradient' layout device xs
  fromTensor tensor = fromTensorRaw tensor

-- | Raw element access for lists.
--
-- A list contributes one leading dimension equal to its length. Element @i@
-- is stored starting at @offset + i * product innerDims@.
instance TensorLikeRaw a => TensorLikeRaw [a] where
  -- The leading dimension is the list length, or 0 when no value is given.
  guessDim :: Maybe [a] -> Maybe Int
  guessDim mxs = Just (maybe 0 length mxs)

  -- Inner dimensions come from the first element; every remaining element's
  -- dims are passed to 'checkDims' against them. With no element available
  -- (no value, or an empty list), guess from the element type alone.
  guessInnerDims :: forall m. MonadThrow m => Maybe [a] -> m [Int]
  guessInnerDims mxs = case mxs >>= nonEmpty of
    Nothing -> guessDims @a Nothing
    Just (hd :| tl) -> do
      hdDims <- guessDims (Just hd)
      forM_ tl $ \y -> guessDims (Just y) >>= checkDims hdDims
      pure hdDims

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO [a]
  tensorPeekElemOff ptr offset (d : innerDims) =
    traverse peekAt [0 .. d - 1]
    where
      blockSize = product innerDims
      peekAt i = tensorPeekElemOff ptr (offset + i * blockSize) innerDims
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @[a] dims' Nothing

  -- Writes at most @d@ elements; extra list elements are ignored.
  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> [a] -> IO ()
  tensorPokeElemOff ptr offset (d : innerDims) xs =
    sequence_ (zipWith pokeAt [0 .. d - 1] xs)
    where
      blockSize = product innerDims
      pokeAt i x = tensorPokeElemOff ptr (offset + i * blockSize) innerDims x
  tensorPokeElemOff _ _ dims' xs = unexpectedDimsError dims' (Just xs)

-- | Boxed vectors convert like lists: a new leading dimension named @"*"@ is
-- inserted at index 0 in front of the element shape, with size
-- 'UncheckedSize' (not fixed at the type level).
instance
  ( TensorLike a dType dims,
    TensorLikeRaw a,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape dims) ('Dim ('Name "*") 'UncheckedSize)
  ) =>
  TensorLike (V.Vector a) dType dimsOut
  where
  -- Both directions delegate to the generic raw conversions, which are
  -- driven by the 'TensorLikeRaw' instance for boxed vectors.
  sToTensor gradient' layout device v = sToTensorRaw gradient' layout device v
  fromTensor tensor = fromTensorRaw tensor

-- | Raw element access for boxed vectors.
--
-- A vector contributes one leading dimension equal to its length. Element
-- @i@ is stored starting at @offset + i * product innerDims@.
instance
  TensorLikeRaw a =>
  TensorLikeRaw (V.Vector a)
  where
  -- The leading dimension is the vector length, or 0 when no value is given.
  guessDim :: Maybe (V.Vector a) -> Maybe Int
  guessDim mv = Just (maybe 0 length mv)

  -- Inner dimensions come from the first element; every remaining element's
  -- dims are passed to 'checkDims' against them. With no element available
  -- (no value, or an empty vector), guess from the element type alone.
  guessInnerDims :: forall m. MonadThrow m => Maybe (V.Vector a) -> m [Int]
  guessInnerDims mv = case mv >>= V.uncons of
    Nothing -> guessDims @a Nothing
    Just (hd, tl) -> do
      hdDims <- guessDims (Just hd)
      forM_ tl $ \y -> guessDims (Just y) >>= checkDims hdDims
      pure hdDims

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO (V.Vector a)
  tensorPeekElemOff ptr offset (d : innerDims) =
    V.mapM peekAt (V.enumFromTo 0 (d - 1))
    where
      blockSize = product innerDims
      peekAt i = tensorPeekElemOff ptr (offset + i * blockSize) innerDims
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @(V.Vector a) dims' Nothing

  -- Writes at most @d@ elements; extra vector elements are ignored.
  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> V.Vector a -> IO ()
  tensorPokeElemOff ptr offset (d : innerDims) xs =
    V.mapM_ pokeAt (V.zip (V.enumFromTo 0 (d - 1)) xs)
    where
      blockSize = product innerDims
      pokeAt (i, x) = tensorPokeElemOff ptr (offset + i * blockSize) innerDims x
  tensorPokeElemOff _ _ dims' xs = unexpectedDimsError dims' (Just xs)

-- | Sized vectors convert to tensors by inserting a new leading dimension
-- named @"*"@ at index 0 in front of the element shape. Unlike lists and
-- boxed vectors, its size is the type-level length @'Size n@.
instance
  ( KnownNat n,
    TensorLike a dType dims,
    TensorLikeRaw a,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape dims) ('Dim ('Name "*") ('Size n))
  ) =>
  TensorLike (SV.Vector n a) dType dimsOut
  where
  -- Both directions delegate to the generic raw conversions, which are
  -- driven by the 'TensorLikeRaw' instance for sized vectors.
  sToTensor gradient' layout device v = sToTensorRaw gradient' layout device v
  fromTensor tensor = fromTensorRaw tensor

-- | Raw element access for sized vectors.
--
-- Storage is delegated to the boxed 'V.Vector' instance. The leading
-- dimension is always the type-level length @n@. Reads and writes are
-- rejected (via 'unexpectedDimsError') when the runtime leading dimension
-- disagrees with @n@, so an 'SV.Vector' of the wrong length is never built
-- and a vector is never partially written.
instance
  ( KnownNat n,
    TensorLikeRaw a
  ) =>
  TensorLikeRaw (SV.Vector n a)
  where
  -- The length is fixed by the type, so report it even when no value is
  -- available (e.g. for the elements of an empty outer list). This keeps
  -- runtime dims consistent with the @'Size n@ dimension in the tensor type.
  guessDim :: Maybe (SV.Vector n a) -> Maybe Int
  guessDim = const . pure . fromIntegral $ natVal (Proxy @n)

  -- Inner dimensions are guessed exactly as for the unsized vector.
  guessInnerDims :: forall m. MonadThrow m => Maybe (SV.Vector n a) -> m [Int]
  guessInnerDims = guessInnerDims . fmap SV.SomeSized

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO (SV.Vector n a)
  tensorPeekElemOff ptr offset dims'@(d : _)
    -- 'SVI.Vector' is the unchecked constructor; only use it when the
    -- number of elements read is exactly @n@.
    | d == fromIntegral (natVal (Proxy @n)) =
        SVI.Vector <$> tensorPeekElemOff ptr offset dims'
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @(SV.Vector n a) dims' empty

  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> SV.Vector n a -> IO ()
  tensorPokeElemOff ptr offset dims'@(d : _) v
    -- Refuse to write when the target dimension differs from @n@, instead
    -- of silently writing only @min d n@ elements.
    | d == fromIntegral (natVal (Proxy @n)) =
        tensorPokeElemOff ptr offset dims' (SV.SomeSized v)
  tensorPokeElemOff _ _ dims' v = unexpectedDimsError dims' $ pure v

-- | Convert a tensor to the gradient, layout, device and data type described
-- by the given singletons. The shape index is left unchanged.
--
-- The conversion calls @ATen.tensor_to_obb@ with the options produced by
-- 'tensorOptions' and with both of its flags (named @nonBlocking@ and
-- @copy@ here) set to 'False'.
sSetTensorOptions ::
  forall gradient layout device dataType gradientFrom layoutFrom deviceFrom dataTypeFrom shape.
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  SDataType dataType ->
  Tensor gradientFrom layoutFrom deviceFrom dataTypeFrom shape ->
  IO (Tensor gradient layout device dataType shape)
sSetTensorOptions gradient' layout device dataType t = do
  let options = tensorOptions gradient' layout device dataType
      nonBlocking = False
      copy = False
  fmap UnsafeTensor (cast4 ATen.tensor_to_obb t options nonBlocking copy)

-- | Like 'sSetTensorOptions', but the target gradient, layout, device and
-- data type are taken from 'SingI' instances. They are usually chosen with
-- type applications, e.g. @setTensorOptions \@gradient \@layout \@device \@dataType@.
setTensorOptions ::
  forall gradient layout device dataType gradientFrom layoutFrom deviceFrom dataTypeFrom shape.
  ( SingI gradient,
    SingI layout,
    SingI device,
    SingI dataType
  ) =>
  Tensor gradientFrom layoutFrom deviceFrom dataTypeFrom shape ->
  IO (Tensor gradient layout device dataType shape)
setTensorOptions t =
  sSetTensorOptions
    (sing @gradient)
    (sing @layout)
    (sing @device)
    (sing @dataType)
    t

-- instance
--   ( SingI gradient,
--     SingI layout,
--     SingI device,
--     SingI dType
--   ) =>
--   TensorLike (Tensor gradient layout device ('DataType dType) ('Shape dims)) dType dims
--   where
--   sToTensor gradient' layout device t = pure $ sSetTensorOptions gradient' layout device dataType t
--     where
--       dataType = SDataType $ sing @dType

--   fromTensor = setTensorOptions @gradient @layout @device @('DataType dType)