{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE ViewPatterns #-}
module Torch.GraduallyTyped.Tensor.Type where
import Control.Applicative (empty)
import Control.Category ((>>>))
import Control.Exception (Exception (..))
import Control.Monad (forM, forM_, when, (<=<), (>=>))
import Control.Monad.Catch (MonadThrow (..))
import Data.Bifunctor (bimap)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Functor ((<&>))
import Data.Int (Int16)
import Data.List.NonEmpty (NonEmpty ((:|)), nonEmpty, unzip)
import Data.Maybe (maybeToList)
import Data.Proxy (Proxy (..))
import Data.Singletons (SingI (sing), fromSing)
import Data.Typeable (Typeable)
import qualified Data.Vector as V hiding (uncons)
import qualified Data.Vector.Generic.Sized.Internal as SVI
import qualified Data.Vector.Sized as SV
import Foreign (Ptr, Word8, castPtr, fromBool, peekElemOff, pokeElemOff, withForeignPtr)
import Foreign.ForeignPtr (ForeignPtr)
import GHC.Generics (Generic)
import GHC.TypeLits (KnownNat, KnownSymbol, Nat, Symbol, natVal, symbolVal)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..), SDeviceType (..))
import Torch.GraduallyTyped.Internal.TensorOptions (tensorOptions)
import qualified Torch.GraduallyTyped.Internal.Vector as V
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, ifM, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (InsertDimF, ReplaceDimF)
import Torch.GraduallyTyped.Shape.Type (By (ByIndex), Dim (..), Name (..), SDim (..), SName (..), SShape (..), SSize (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.HList (HList (..), pattern (:.))
import Torch.Internal.Cast (cast0, cast1, cast2, cast4)
import Torch.Internal.Class (Castable (..))
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.Extra as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen (Tensor, TensorList, TensorOptions)
import qualified Torch.Internal.Unmanaged.Type.Tensor as Unmanaged (tensor_data_ptr)
import qualified Torch.Tensor (Tensor (Unsafe))
import Prelude hiding (unzip, unzip3)
-- | A tensor backed by an ATen tensor handle.
--
-- At runtime a value is nothing but the wrapped 'ForeignPtr'. The five type
-- parameters are phantom indices that describe, at the type level, the
-- gradient setting, layout, device, data type, and shape of the tensor.
newtype
  Tensor
    (gradient :: Gradient RequiresGradient)
    (layout :: Layout LayoutType)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (shape :: Shape [Dim (Name Symbol) (Size Nat)])
  = UnsafeTensor (ForeignPtr ATen.Tensor)

-- Every index is nominal, so 'coerce' cannot change an index unless the
-- 'UnsafeTensor' constructor is in scope to unwrap the newtype.
type role Tensor nominal nominal nominal nominal nominal
-- | Shows a typed tensor by re-wrapping its handle as an untyped
-- 'Torch.Tensor.Tensor' and delegating to that type's 'Show' instance.
instance Show (Tensor gradient layout device dataType shape) where
  show :: Tensor gradient layout device dataType shape -> String
  show (UnsafeTensor atenHandle) =
    let untyped = Torch.Tensor.Unsafe atenHandle
     in show untyped
-- | Runtime description of a tensor: one singleton for each type index of
-- 'Tensor', in the order gradient, layout, device, data type, shape.
data
  TensorSpec
    (gradient :: Gradient RequiresGradient)
    (layout :: Layout LayoutType)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (shape :: Shape [Dim (Name Symbol) (Size Nat)])
  = TensorSpec
  { tsGradient :: SGradient gradient
  -- ^ Singleton for the gradient index.
  , tsLayout :: SLayout layout
  -- ^ Singleton for the layout index.
  , tsDevice :: SDevice device
  -- ^ Singleton for the device index.
  , tsDataType :: SDataType dataType
  -- ^ Singleton for the data-type index.
  , tsShape :: SShape shape
  -- ^ Singleton for the shape index.
  }
  deriving stock (Show, Generic)
-- | A tensor whose gradient, layout, device, data type, and shape are all
-- unchecked at the type level.
type UncheckedTensor = Tensor 'UncheckedGradient 'UncheckedLayout 'UncheckedDevice 'UncheckedDataType 'UncheckedShape
-- | Like 'UncheckedTensor', but statically marked as requiring gradients.
type UncheckedParameter = Tensor ('Gradient 'WithGradient) 'UncheckedLayout 'UncheckedDevice 'UncheckedDataType 'UncheckedShape
-- | Dense CPU tensor without gradients; data type and shape remain parameters.
type CPUTensor = Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) ('Device 'CPU)
-- | Dense CPU tensor with gradients; data type and shape remain parameters.
type CPUParameter = Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU)
-- | Sparse CPU tensor without gradients; data type and shape remain parameters.
type SparseCPUTensor = Tensor ('Gradient 'WithoutGradient) ('Layout 'Sparse) ('Device 'CPU)
-- | Sparse CPU tensor with gradients; data type and shape remain parameters.
type SparseCPUParameter = Tensor ('Gradient 'WithGradient) ('Layout 'Sparse) ('Device 'CPU)
-- | Dense tensor on CUDA device @deviceId@ without gradients; data type and
-- shape remain parameters.
type CUDATensor deviceId = Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) ('Device ('CUDA deviceId))
-- | Dense tensor on CUDA device @deviceId@ with gradients; data type and
-- shape remain parameters.
type CUDAParameter deviceId = Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device ('CUDA deviceId))
-- | Sparse tensor on CUDA device @deviceId@ without gradients; data type and
-- shape remain parameters.
type SparseCUDATensor deviceId = Tensor ('Gradient 'WithoutGradient) ('Layout 'Sparse) ('Device ('CUDA deviceId))
-- | Sparse tensor on CUDA device @deviceId@ with gradients; data type and
-- shape remain parameters.
type SparseCUDAParameter deviceId = Tensor ('Gradient 'WithGradient) ('Layout 'Sparse) ('Device ('CUDA deviceId))
-- | Element-wise arithmetic on tensors through the ATen bindings.
--
-- The method signatures require both operands and the result to carry
-- identical type indices.
--
-- Each operation runs its ATen call under 'unsafePerformIO' so that it can be
-- exposed through the pure 'Num' interface.
-- NOTE(review): this is only sound if 'ATen.add_tt', 'ATen.sub_tt',
-- 'ATen.mul_tt', 'ATen.neg_t', 'ATen.abs_t' and 'ATen.sign_t' allocate a
-- fresh result and leave their inputs untouched — confirm against the
-- binding definitions.
instance Num (Tensor gradient layout device dataType shape) where
  (+) ::
    Tensor gradient layout device dataType shape ->
    Tensor gradient layout device dataType shape ->
    Tensor gradient layout device dataType shape
  (+) = (unsafePerformIO .) . cast2 ATen.add_tt

  (-) ::
    Tensor gradient layout device dataType shape ->
    Tensor gradient layout device dataType shape ->
    Tensor gradient layout device dataType shape
  (-) = (unsafePerformIO .) . cast2 ATen.sub_tt

  (*) ::
    Tensor gradient layout device dataType shape ->
    Tensor gradient layout device dataType shape ->
    Tensor gradient layout device dataType shape
  (*) = (unsafePerformIO .) . cast2 ATen.mul_tt

  negate ::
    Tensor gradient layout device dataType shape ->
    Tensor gradient layout device dataType shape
  negate = unsafePerformIO . cast1 ATen.neg_t

  abs ::
    Tensor gradient layout device dataType shape ->
    Tensor gradient layout device dataType shape
  abs = unsafePerformIO . cast1 ATen.abs_t

  signum ::
    Tensor gradient layout device dataType shape ->
    Tensor gradient layout device dataType shape
  signum = unsafePerformIO . cast1 ATen.sign_t

  -- | Not supported. The instance places no constraints on the type indices,
  -- so nothing about the requested layout, device, data type or shape is
  -- available at runtime to build a tensor from an 'Integer'. Failing with a
  -- descriptive message replaces the former message-less 'undefined'.
  fromInteger :: Integer -> Tensor gradient layout device dataType shape
  fromInteger _ =
    error $
      "Torch.GraduallyTyped.Tensor.Type: fromInteger is not supported for Tensor, "
        <> "because an Integer alone does not determine the tensor's layout, device, "
        <> "data type, or shape. Construct the tensor explicitly instead."
-- | Converts between a typed tensor and its raw ATen handle in
-- continuation-passing style; both directions only wrap or unwrap the
-- 'UnsafeTensor' newtype.
instance
  Castable
    (Tensor gradient layout device dataType shape)
    (ForeignPtr ATen.Tensor)
  where
  cast ::
    forall r.
    Tensor gradient layout device dataType shape ->
    (ForeignPtr ATen.Tensor -> IO r) ->
    IO r
  cast (UnsafeTensor handle) continue = continue handle

  uncast ::
    forall r.
    ForeignPtr ATen.Tensor ->
    (Tensor gradient layout device dataType shape -> IO r) ->
    IO r
  uncast handle continue = continue (UnsafeTensor handle)
-- | Marshals a list of tensors to and from an ATen @TensorList@. The
-- conversion goes element-wise through @[ForeignPtr ATen.Tensor]@ and then
-- uses the list-level 'Castable' instance for that type.
instance
  Castable
    [Tensor gradient layout device dataType shape]
    (ForeignPtr ATen.TensorList)
  where
  cast ::
    forall r.
    [Tensor gradient layout device dataType shape] ->
    (ForeignPtr ATen.TensorList -> IO r) ->
    IO r
  cast tensors continue =
    traverse toHandle tensors >>= \handles -> cast handles continue
    where
      toHandle :: Tensor gradient layout device dataType shape -> IO (ForeignPtr ATen.Tensor)
      toHandle tensor = cast tensor pure

  uncast ::
    forall r.
    ForeignPtr ATen.TensorList ->
    ([Tensor gradient layout device dataType shape] -> IO r) ->
    IO r
  uncast listHandle continue =
    uncast listHandle $ \(handles :: [ForeignPtr ATen.Tensor]) ->
      traverse fromHandle handles >>= continue
    where
      fromHandle :: ForeignPtr ATen.Tensor -> IO (Tensor gradient layout device dataType shape)
      fromHandle handle = uncast handle pure
-- | Base case of the heterogeneous-list marshalling: an empty 'HList' maps to
-- the empty handle list. Unmarshalling a non-empty handle list into an empty
-- 'HList' fails in 'IO' via 'fail'.
instance Castable (HList '[]) [ForeignPtr ATen.Tensor] where
  cast :: forall r. HList '[] -> ([ForeignPtr ATen.Tensor] -> IO r) -> IO r
  cast HNil continue = continue []

  uncast :: forall r. [ForeignPtr ATen.Tensor] -> (HList '[] -> IO r) -> IO r
  uncast handles continue
    | null handles = continue HNil
    | otherwise =
      fail "The list of tensors has more elements than expected. This means that the runtime length of the list exceeded its compile-time length."
-- | Inductive case of the heterogeneous-list marshalling. The head tensor is
-- converted to its handle and the tail is handled by the instance for
-- @HList tensors@. Unmarshalling an empty handle list fails in 'IO' via
-- 'fail', because at least one element is required here.
instance
  (Castable (HList tensors) [ForeignPtr ATen.Tensor]) =>
  Castable
    (HList (Tensor gradient layout device dataType shape ': tensors))
    [ForeignPtr ATen.Tensor]
  where
  cast ::
    forall r.
    HList (Tensor gradient layout device dataType shape ': tensors) ->
    ([ForeignPtr ATen.Tensor] -> IO r) ->
    IO r
  cast (HCons (tensor, rest)) continue =
    ((:) <$> cast tensor pure <*> cast rest pure) >>= continue

  uncast ::
    forall r.
    [ForeignPtr ATen.Tensor] ->
    (HList (Tensor gradient layout device dataType shape ': tensors) -> IO r) ->
    IO r
  uncast handles continue = case handles of
    [] ->
      fail "The list of tensors ended prematurely. This means that the runtime length of the list was smaller than its compile-time length."
    handle : rest ->
      ((\t ts -> t :. ts) <$> uncast handle pure <*> uncast rest pure) >>= continue
-- | Marshals any 'HList' that can be converted to @[ForeignPtr ATen.Tensor]@
-- to and from an ATen @TensorList@, going through that intermediate list.
instance
  Castable (HList l) [ForeignPtr ATen.Tensor] =>
  Castable (HList l) (ForeignPtr ATen.TensorList)
  where
  cast :: forall r. HList l -> (ForeignPtr ATen.TensorList -> IO r) -> IO r
  cast hlist continue =
    (cast hlist pure :: IO [ForeignPtr ATen.Tensor]) >>= (`cast` continue)

  uncast :: forall r. ForeignPtr ATen.TensorList -> (HList l -> IO r) -> IO r
  uncast listHandle continue =
    uncast listHandle $ \(handles :: [ForeignPtr ATen.Tensor]) ->
      (uncast handles pure :: IO (HList l)) >>= continue
-- | Calls 'ATen.tensor_set_requires_grad_b' with @False@ and returns the
-- result typed as not requiring gradients.
--
-- NOTE(review): the binding presumably wraps ATen's in-place
-- @set_requires_grad@, in which case the input tensor's flag changes as
-- well — confirm.
withoutGradient ::
  forall gradient layout device dataType shape.
  Tensor gradient layout device dataType shape ->
  IO (Tensor ('Gradient 'WithoutGradient) layout device dataType shape)
withoutGradient input =
  cast2 ATen.tensor_set_requires_grad_b input requiresGrad
  where
    requiresGrad = False
-- | Calls 'ATen.tensor_set_requires_grad_b' with @True@ and returns the
-- result typed as requiring gradients.
--
-- NOTE(review): the binding presumably wraps ATen's in-place
-- @set_requires_grad@, in which case the input tensor's flag changes as
-- well — confirm.
withGradient ::
  forall gradient layout device dataType shape.
  Tensor gradient layout device dataType shape ->
  IO (Tensor ('Gradient 'WithGradient) layout device dataType shape)
withGradient input =
  cast2 ATen.tensor_set_requires_grad_b input requiresGrad
  where
    requiresGrad = True
-- | Sets the tensor's @requires_grad@ flag according to a gradient singleton
-- and retypes the result with that gradient index.
--
-- The singleton is demoted with 'fromSing' and its checked/unchecked wrapper
-- is dropped, so both checked and unchecked singletons are accepted.
sSetGradient ::
  forall gradient gradient' layout device dataType shape.
  SGradient gradient ->
  Tensor gradient' layout device dataType shape ->
  IO (Tensor gradient layout device dataType shape)
sSetGradient requested input =
  let flag = case forgetIsChecked (fromSing requested) of
        WithoutGradient -> False
        WithGradient -> True
   in cast2 ATen.tensor_set_requires_grad_b input flag
-- | Gradient indices for which a tensor's gradient setting can be recovered
-- as a singleton.
class SGetGradient (gradient :: Gradient RequiresGradient) where
  -- | The gradient singleton describing the given tensor.
  sGetGradient ::
    forall layout device dataType shape.
    Tensor gradient layout device dataType shape ->
    SGradient gradient

  -- | The tensor's gradient setting as a plain 'RequiresGradient' value. By
  -- default this demotes 'sGetGradient' and drops the checked/unchecked
  -- wrapper.
  getRequiresGradient ::
    forall layout device dataType shape.
    Tensor gradient layout device dataType shape ->
    RequiresGradient
  getRequiresGradient = forgetIsChecked . fromSing . sGetGradient
-- | For an unchecked gradient index, the tensor's runtime @requires_grad@
-- flag is read through 'ATen.tensor_requires_grad' and reported as an
-- unchecked singleton.
instance SGetGradient 'UncheckedGradient where
  sGetGradient ::
    forall layout device dataType shape.
    Tensor 'UncheckedGradient layout device dataType shape ->
    SGradient 'UncheckedGradient
  sGetGradient tensor =
    SUncheckedGradient $
      if requiresGrad then WithGradient else WithoutGradient
    where
      requiresGrad :: Bool
      requiresGrad = unsafePerformIO (cast1 ATen.tensor_requires_grad tensor)
-- | For a tensor typed as requiring gradients, the runtime flag is read. The
-- instance returns the checked singleton when the flag agrees and calls
-- 'error' when it does not.
instance SGetGradient ('Gradient 'WithGradient) where
  sGetGradient ::
    forall layout device dataType shape.
    Tensor ('Gradient 'WithGradient) layout device dataType shape ->
    SGradient ('Gradient 'WithGradient)
  sGetGradient tensor =
    if requiresGrad
      then SGradient SWithGradient
      else
        error $
          "The tensor should require gradient computations but doesn't. "
            <> gitHubErrorMsg
    where
      requiresGrad :: Bool
      requiresGrad = unsafePerformIO (cast1 ATen.tensor_requires_grad tensor)
-- | For a tensor typed as not requiring gradients, the runtime flag is read.
-- The instance calls 'error' if the flag is set and otherwise returns the
-- checked singleton.
instance SGetGradient ('Gradient 'WithoutGradient) where
  sGetGradient ::
    forall layout device dataType shape.
    Tensor ('Gradient 'WithoutGradient) layout device dataType shape ->
    SGradient ('Gradient 'WithoutGradient)
  sGetGradient tensor =
    if not requiresGrad
      then SGradient SWithoutGradient
      else
        error $
          "The tensor should not require gradient computations but does. "
            <> gitHubErrorMsg
    where
      requiresGrad :: Bool
      requiresGrad = unsafePerformIO (cast1 ATen.tensor_requires_grad tensor)
-- | Exception thrown by 'sCheckedGradient' when a tensor's gradient setting
-- differs from the requested one.
data GradientError = GradientError
  { geExpected :: RequiresGradient
  -- ^ The gradient setting that was requested.
  , geActual :: RequiresGradient
  -- ^ The gradient setting the tensor actually has.
  }
  deriving stock (Show, Typeable)
-- | Human-readable rendering that reports the actual gradient setting first,
-- followed by the expected one.
instance Exception GradientError where
  displayException :: GradientError -> String
  displayException err =
    concat
      [ "The tensor's information about whether or not gradient computations are required reads `"
      , show (geActual err)
      , "` instead of `"
      , show (geExpected err)
      , "`."
      ]
-- | Retypes a tensor to the gradient index described by a singleton, after
-- checking that the tensor's gradient setting matches it.
--
-- If the settings agree, the tensor is re-indexed with 'coerce'. If they
-- differ, a 'GradientError' is thrown with 'throwM'.
sCheckedGradient ::
  forall gradient' m gradient layout device dataType shape.
  (SGetGradient gradient, MonadThrow m, Catch (gradient <+> gradient')) =>
  SGradient gradient' ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient' layout device dataType shape)
sCheckedGradient gradient' tensor
  | actual == expected = pure (coerce tensor)
  | otherwise = throwM (GradientError {geExpected = expected, geActual = actual})
  where
    actual :: RequiresGradient
    actual = forgetIsChecked (fromSing (sGetGradient tensor))
    expected :: RequiresGradient
    expected = forgetIsChecked (fromSing gradient')
-- | Same as 'sCheckedGradient', but the target gradient index comes from the
-- type via 'SingI' instead of an explicit singleton argument.
checkedGradient ::
  forall gradient' m gradient layout device dataType shape.
  (SingI gradient', SGetGradient gradient, MonadThrow m, Catch (gradient <+> gradient')) =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient' layout device dataType shape)
checkedGradient tensor = sCheckedGradient target tensor
  where
    target = sing @gradient'
-- | Forgets the type-level gradient information. This is a pure retag of the
-- same underlying handle; nothing is checked or changed at runtime.
uncheckedGradient ::
  forall gradient layout device dataType shape.
  Tensor gradient layout device dataType shape ->
  Tensor 'UncheckedGradient layout device dataType shape
uncheckedGradient (UnsafeTensor handle) = UnsafeTensor handle
-- | Converts a tensor to dense layout via 'ATen.tensor_to_dense'. The call
-- runs through 'unsafeThrowableIO', so it can be used in any 'MonadThrow'.
toDense ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient ('Layout 'Dense) device dataType shape)
toDense tensor = unsafeThrowableIO (cast1 ATen.tensor_to_dense tensor)
-- | Converts a tensor to sparse layout via 'ATen.tensor_to_sparse'. The call
-- runs through 'unsafeThrowableIO', so it can be used in any 'MonadThrow'.
toSparse ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient ('Layout 'Sparse) device dataType shape)
toSparse tensor = unsafeThrowableIO (cast1 ATen.tensor_to_sparse tensor)
-- | Converts a tensor to the layout described by a layout singleton and
-- retypes it accordingly.
--
-- The demoted layout chooses which ATen conversion to run: 'Dense' uses
-- 'ATen.tensor_to_dense' and 'Sparse' uses 'ATen.tensor_to_sparse'.
sSetLayout ::
  forall m gradient layout layout' device dataType shape.
  MonadThrow m =>
  SLayout layout ->
  Tensor gradient layout' device dataType shape ->
  m (Tensor gradient layout device dataType shape)
sSetLayout target input =
  let convert = case forgetIsChecked (fromSing target) of
        Dense -> ATen.tensor_to_dense
        Sparse -> ATen.tensor_to_sparse
   in unsafeThrowableIO (cast1 convert input)
-- | Types whose memory layout can be recovered from a tensor as a singleton.
class SGetLayout (layout :: Layout LayoutType) where
  -- | Returns the gradually typed memory layout of a tensor.
  sGetLayout ::
    forall gradient device dataType shape.
    -- | input tensor
    Tensor gradient layout device dataType shape ->
    -- | memory layout singleton
    SLayout layout

  -- | Returns the untyped memory layout of a tensor.
  --
  -- The default demotes the singleton returned by 'sGetLayout'.
  getLayoutType ::
    forall gradient device dataType shape.
    -- | input tensor
    Tensor gradient layout device dataType shape ->
    -- | memory layout
    LayoutType
  getLayoutType = forgetIsChecked . fromSing . sGetLayout
-- | The layout is not tracked in the type, so it is read from libtorch at
-- runtime: sparse if @tensor_is_sparse@ reports so, dense otherwise.
instance SGetLayout 'UncheckedLayout where
  sGetLayout tensor =
    SUncheckedLayout $
      if isSparse then Sparse else Dense
    where
      -- NOTE(review): 'unsafePerformIO' assumes @tensor_is_sparse@ is a
      -- side-effect-free metadata query — confirm.
      isSparse :: Bool
      isSparse = unsafePerformIO (cast1 ATen.tensor_is_sparse tensor)
-- | The type claims a sparse layout; this is double-checked against libtorch,
-- and a mismatch is reported with 'error' as an internal invariant violation.
instance SGetLayout ('Layout 'Sparse) where
  sGetLayout tensor =
    if isSparse
      then SLayout SSparse
      else error ("The tensor should be sparse but isn't. " <> gitHubErrorMsg)
    where
      isSparse :: Bool
      isSparse = unsafePerformIO (cast1 ATen.tensor_is_sparse tensor)
-- | The type claims a dense layout; this is double-checked against libtorch,
-- and a sparse tensor is reported with 'error' as an internal invariant
-- violation.
instance SGetLayout ('Layout 'Dense) where
  sGetLayout tensor =
    if isSparse
      then error ("The tensor should be dense but isn't. " <> gitHubErrorMsg)
      else SLayout SDense
    where
      isSparse :: Bool
      isSparse = unsafePerformIO (cast1 ATen.tensor_is_sparse tensor)
-- | Exception raised when a tensor's runtime memory layout differs from the
-- one that was requested.
data LayoutError = LayoutError
  { -- | The memory layout the tensor was expected to have.
    leExpected :: LayoutType,
    -- | The memory layout the tensor actually has.
    leActual :: LayoutType
  }
  deriving stock (Show, Typeable)
-- | Renders the expected and the actual layout in a human-readable message.
instance Exception LayoutError where
  displayException (LayoutError expected actual) =
    concat
      [ "The tensor does not have the memory layout `",
        show expected,
        "` but `",
        show actual,
        "`."
      ]
-- | Checks at runtime that a tensor has the memory layout given by the
-- singleton and, if so, re-types it with that layout.
--
-- No data is converted: on success the tensor is merely 'coerce'd. On a
-- mismatch a 'LayoutError' carrying the expected and actual layouts is thrown
-- with 'throwM'. The @Catch (layout <+> layout')@ constraint presumably rejects
-- statically incompatible layouts at compile time — confirm.
sCheckedLayout ::
  forall layout' m gradient layout device dataType shape.
  (SGetLayout layout, MonadThrow m, Catch (layout <+> layout')) =>
  -- | expected memory layout
  SLayout layout' ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | the same tensor, typed with the checked layout
  m (Tensor gradient layout' device dataType shape)
sCheckedLayout expected tensor
  | actualLayout == expectedLayout = pure (coerce tensor)
  | otherwise = throwM (LayoutError expectedLayout actualLayout)
  where
    actualLayout, expectedLayout :: LayoutType
    actualLayout = forgetIsChecked (fromSing (sGetLayout tensor))
    expectedLayout = forgetIsChecked (fromSing expected)
-- | Like 'sCheckedLayout', but the expected layout comes from the type
-- argument @layout'@ (via 'sing') instead of an explicit singleton.
checkedLayout ::
  forall layout' m gradient layout device dataType shape.
  (SingI layout', SGetLayout layout, MonadThrow m, Catch (layout <+> layout')) =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | the same tensor, typed with the checked layout
  m (Tensor gradient layout' device dataType shape)
checkedLayout tensor = sCheckedLayout (sing @layout') tensor
-- | Forgets the type-level memory layout of a tensor.
--
-- This is only a 'coerce'; no runtime check or conversion happens.
uncheckedLayout ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | the same tensor with an unchecked layout
  Tensor gradient 'UncheckedLayout device dataType shape
uncheckedLayout tensor = coerce tensor
-- | Moves a tensor to CPU memory via @ATen.tensor_cpu@, lifted into @m@ with
-- 'unsafeThrowableIO'.
cpu ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | tensor on the CPU
  m (Tensor gradient layout ('Device 'CPU) dataType shape)
cpu tensor = unsafeThrowableIO (cast1 ATen.tensor_cpu tensor)
-- | Moves a tensor to CUDA via @ATen.tensor_cuda@ and types the result as
-- living on CUDA device 0.
--
-- NOTE(review): if @tensor_cuda@ targets libtorch's /current/ CUDA device,
-- the result type is only accurate while that device is 0 — confirm.
cuda ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | tensor on CUDA device 0
  m (Tensor gradient layout ('Device ('CUDA 0)) dataType shape)
cuda tensor = unsafeThrowableIO (cast1 ATen.tensor_cuda tensor)
-- | Moves a tensor to the device described by the given singleton.
--
-- * CPU uses @ATen.tensor_cpu@.
-- * CUDA device 0 uses @ATen.tensor_cuda@.
-- * Any other CUDA index copies the input's tensor options, overrides the
--   device index, and calls @ATen.tensor_to_obb@ with @non_blocking = False@
--   and @copy = False@.
--
-- NOTE(review): the device-0 case relies on @tensor_cuda@; if that targets
-- the /current/ CUDA device rather than device 0, the two paths differ —
-- confirm.
sSetDevice ::
  forall m gradient layout device device' dataType shape.
  MonadThrow m =>
  -- | target device
  SDevice device ->
  -- | input tensor
  Tensor gradient layout device' dataType shape ->
  -- | tensor on the requested device
  m (Tensor gradient layout device dataType shape)
sSetDevice targetDevice tensor =
  unsafeThrowableIO $ case forgetIsChecked (fromSing targetDevice) of
    CPU -> cast1 ATen.tensor_cpu tensor
    CUDA 0 -> cast1 ATen.tensor_cuda tensor
    CUDA deviceIndex -> moveToCUDA deviceIndex
  where
    -- Re-target the input's own options at the requested CUDA index.
    moveToCUDA deviceIndex = do
      options :: ForeignPtr ATen.TensorOptions <- cast1 ATen.tensor_options tensor
      options' :: ForeignPtr ATen.TensorOptions <-
        cast2 ATen.tensorOptions_device_index_s options deviceIndex
      cast4 ATen.tensor_to_obb tensor options' nonBlocking copy
    nonBlocking, copy :: Bool
    nonBlocking = False
    copy = False
-- | Types whose compute device can be recovered from a tensor as a singleton.
class SGetDevice (device :: Device (DeviceType Nat)) where
  -- | Returns the gradually typed compute device of a tensor.
  sGetDevice ::
    forall gradient layout dataType shape.
    -- | input tensor
    Tensor gradient layout device dataType shape ->
    -- | compute device singleton
    SDevice device

  -- | Returns the untyped compute device of a tensor.
  --
  -- The default demotes the singleton returned by 'sGetDevice'.
  getDeviceType ::
    forall gradient layout dataType shape.
    -- | input tensor
    Tensor gradient layout device dataType shape ->
    -- | compute device
    DeviceType Int16
  getDeviceType = forgetIsChecked . fromSing . sGetDevice
-- | The device is not tracked in the type, so it is read from libtorch at
-- runtime.
--
-- A tensor counts as being on CUDA only when libtorch was built with CUDA
-- (@hasCUDA@) and reports the tensor as a CUDA tensor. In that case the CUDA
-- index comes from @tensor_get_device@; otherwise the result is 'CPU'.
instance SGetDevice 'UncheckedDevice where
  sGetDevice tensor
    | not onCUDA = SUncheckedDevice CPU
    | otherwise = SUncheckedDevice (CUDA (fromIntegral deviceIndex))
    where
      -- '&&' short-circuits, so @tensor_is_cuda@ is only queried when CUDA
      -- is available.
      onCUDA :: Bool
      onCUDA =
        unsafePerformIO (cast0 ATen.hasCUDA)
          && unsafePerformIO (cast1 ATen.tensor_is_cuda tensor)
      -- Evaluated lazily, i.e. only on the CUDA branch.
      deviceIndex :: Int
      deviceIndex = unsafePerformIO (cast1 ATen.tensor_get_device tensor)
-- | The type claims the tensor is on the CPU; this is double-checked against
-- libtorch, and a CUDA tensor is reported with 'error' as an internal
-- invariant violation.
instance SGetDevice ('Device 'CPU) where
  sGetDevice tensor =
    if onCUDA
      then error ("The tensor should be on CPU but is on CUDA. " <> gitHubErrorMsg)
      else SDevice SCPU
    where
      onCUDA :: Bool
      onCUDA =
        unsafePerformIO (cast0 ATen.hasCUDA)
          && unsafePerformIO (cast1 ATen.tensor_is_cuda tensor)
-- | The type claims the tensor is on CUDA device @deviceIndex@; this is
-- double-checked against libtorch.
--
-- Reports with 'error' (an internal invariant violation) when the tensor is
-- not on CUDA, or when it is on a CUDA device with a different index.
instance KnownNat deviceIndex => SGetDevice ('Device ('CUDA deviceIndex)) where
  sGetDevice tensor
    | not onCUDA =
      error $ "The tensor should be on CUDA but is on CPU. " <> gitHubErrorMsg
    | actualIndex /= fromIntegral expectedIndex =
      error $
        concat
          [ "The tensor should be on CUDA device ",
            show expectedIndex,
            " but is on device ",
            show actualIndex,
            ". ",
            gitHubErrorMsg
          ]
    | otherwise = SDevice SCUDA
    where
      -- Index demanded by the type.
      expectedIndex :: Integer
      expectedIndex = natVal (Proxy @deviceIndex)
      -- Index reported by libtorch; only forced once 'onCUDA' holds.
      actualIndex :: Int
      actualIndex = unsafePerformIO (cast1 ATen.tensor_get_device tensor)
      onCUDA :: Bool
      onCUDA =
        unsafePerformIO (cast0 ATen.hasCUDA)
          && unsafePerformIO (cast1 ATen.tensor_is_cuda tensor)
-- | Exception raised when a tensor's runtime compute device differs from the
-- one that was requested.
data DeviceError = DeviceError
  { -- | The device the tensor was expected to be on.
    deExpected :: DeviceType Int16,
    -- | The device the tensor is actually on.
    deActual :: DeviceType Int16
  }
  deriving stock (Show, Typeable)
-- | Renders the expected and the actual device in a human-readable message.
instance Exception DeviceError where
  displayException (DeviceError expected actual) =
    concat
      [ "The tensor is not in the memory of the device `",
        show expected,
        "` but `",
        show actual,
        "`."
      ]
-- | Checks at runtime that a tensor is on the device given by the singleton
-- and, if so, re-types it with that device.
--
-- No data is moved: on success the tensor is merely 'coerce'd. On a mismatch
-- a 'DeviceError' carrying the expected and actual devices is thrown with
-- 'throwM'. The @Catch (device <+> device')@ constraint presumably rejects
-- statically incompatible devices at compile time — confirm.
sCheckedDevice ::
  forall device' m gradient layout device dataType shape.
  (SGetDevice device, MonadThrow m, Catch (device <+> device')) =>
  -- | expected device
  SDevice device' ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | the same tensor, typed with the checked device
  m (Tensor gradient layout device' dataType shape)
sCheckedDevice expected tensor
  | actualDevice == expectedDevice = pure (coerce tensor)
  | otherwise = throwM (DeviceError expectedDevice actualDevice)
  where
    actualDevice, expectedDevice :: DeviceType Int16
    actualDevice = forgetIsChecked (fromSing (sGetDevice tensor))
    expectedDevice = forgetIsChecked (fromSing expected)
-- | Like 'sCheckedDevice', but the expected device comes from the type
-- argument @device'@ (via 'sing') instead of an explicit singleton.
checkedDevice ::
  forall device' m gradient layout device dataType shape.
  (SingI device', SGetDevice device, MonadThrow m, Catch (device <+> device')) =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | the same tensor, typed with the checked device
  m (Tensor gradient layout device' dataType shape)
checkedDevice tensor = sCheckedDevice (sing @device') tensor
-- | Forgets the type-level compute device of a tensor.
--
-- This is only a 'coerce'; no runtime check or transfer happens.
uncheckedDevice ::
  forall gradient layout device dataType shape.
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | the same tensor with an unchecked device
  Tensor gradient layout 'UncheckedDevice dataType shape
uncheckedDevice tensor = coerce tensor
-- | Converts a tensor to the 'Bool' data type via @ATen.tensor_toType_s@.
--
-- The result is typed as not requiring gradients.
bool ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | converted tensor
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Bool) shape)
bool tensor = unsafeThrowableIO (cast2 ATen.tensor_toType_s tensor Bool)
-- | Converts a tensor to the 'UInt8' data type via @ATen.tensor_toType_s@.
--
-- The result is typed as not requiring gradients.
byte ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | converted tensor
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'UInt8) shape)
byte tensor = unsafeThrowableIO (cast2 ATen.tensor_toType_s tensor UInt8)
-- | Converts a tensor to the 'Int8' data type via @ATen.tensor_toType_s@.
--
-- The result is typed as not requiring gradients.
char ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | converted tensor
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int8) shape)
char tensor = unsafeThrowableIO (cast2 ATen.tensor_toType_s tensor Int8)
-- | Converts a tensor to the 'Int16' data type via @ATen.tensor_toType_s@.
--
-- The result is typed as not requiring gradients.
short ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | converted tensor
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int16) shape)
short tensor = unsafeThrowableIO (cast2 ATen.tensor_toType_s tensor Int16)
-- | Converts a tensor to the 'Int32' data type via @ATen.tensor_toType_s@.
--
-- The result is typed as not requiring gradients.
int ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | converted tensor
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int32) shape)
int tensor = unsafeThrowableIO (cast2 ATen.tensor_toType_s tensor Int32)
-- | Converts a tensor to the 'Int64' data type via @ATen.tensor_toType_s@.
--
-- The result is typed as not requiring gradients.
long ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | converted tensor
  m (Tensor ('Gradient 'WithoutGradient) layout device ('DataType 'Int64) shape)
long tensor = unsafeThrowableIO (cast2 ATen.tensor_toType_s tensor Int64)
-- | Converts a tensor to the 'Half' data type via @ATen.tensor_toType_s@.
--
-- Unlike the integral conversions, the gradient type is kept unchanged.
half ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | converted tensor
  m (Tensor gradient layout device ('DataType 'Half) shape)
half tensor = unsafeThrowableIO (cast2 ATen.tensor_toType_s tensor Half)
-- | Converts a tensor to the 'Float' data type via @ATen.tensor_toType_s@.
--
-- Unlike the integral conversions, the gradient type is kept unchanged.
float ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | converted tensor
  m (Tensor gradient layout device ('DataType 'Float) shape)
float tensor = unsafeThrowableIO (cast2 ATen.tensor_toType_s tensor Float)
-- | Converts a tensor to the 'Double' data type via @ATen.tensor_toType_s@.
--
-- Unlike the integral conversions, the gradient type is kept unchanged.
double ::
  forall m gradient layout device dataType shape.
  MonadThrow m =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | converted tensor
  m (Tensor gradient layout device ('DataType 'Double) shape)
double tensor = unsafeThrowableIO (cast2 ATen.tensor_toType_s tensor Double)
-- | Convert a tensor to the data type described by a 'SDataType' singleton.
--
-- The target 'DType' is obtained by demoting the singleton and dropping its
-- checked/unchecked wrapper. The steps are:
--
-- 1. Read the tensor's current options.
-- 2. Replace the dtype in those options.
-- 3. Convert the tensor with @ATen.tensor_to_obb@, passing 'False' for both
--    boolean flags (named @nonBlocking@ and @copy@ here).
--
-- The 'IO' work runs through 'unsafeThrowableIO'.
sSetDataType ::
  forall m gradient layout device dataType dataType' shape.
  MonadThrow m =>
  SDataType dataType ->
  Tensor gradient layout device dataType' shape ->
  m (Tensor gradient layout device dataType shape)
sSetDataType dataType input = unsafeThrowableIO $ do
  -- The annotations pin down which 'Castable' instances 'cast1'/'cast2' use.
  currentOptions :: ForeignPtr ATen.TensorOptions <-
    cast1 ATen.tensor_options input
  targetOptions :: ForeignPtr ATen.TensorOptions <-
    cast2 ATen.tensorOptions_dtype_s currentOptions targetDType
  cast4 ATen.tensor_to_obb input targetOptions nonBlocking copy
  where
    targetDType :: DType
    targetDType = forgetIsChecked (fromSing dataType)

    nonBlocking, copy :: Bool
    nonBlocking = False
    copy = False
-- | Type-level data types whose singleton can be recovered from a tensor.
class SGetDataType (dataType :: DataType DType) where
  -- | Retrieve the data type singleton of a tensor.
  sGetDataType ::
    forall gradient layout device shape.
    Tensor gradient layout device dataType shape ->
    SDataType dataType

  -- | Retrieve the data type of a tensor as a plain 'DType'.
  --
  -- The default demotes the singleton returned by 'sGetDataType' and strips
  -- the checked/unchecked wrapper.
  getDType ::
    forall gradient layout device shape.
    Tensor gradient layout device dataType shape ->
    DType
  getDType = forgetIsChecked . fromSing . sGetDataType
-- | When the data type is not tracked statically, it is read from the tensor
-- at runtime via @ATen.tensor_scalar_type@.
instance SGetDataType 'UncheckedDataType where
  sGetDataType tensor =
    let runtimeDType :: DType
        runtimeDType = unsafePerformIO (cast1 ATen.tensor_scalar_type tensor)
     in SUncheckedDataType runtimeDType
-- | When the data type is known statically, the tensor's runtime scalar type
-- (read via @ATen.tensor_scalar_type@) is compared with it.
--
-- On a match, the static singleton is returned. On a mismatch, the function
-- calls 'error', since the result is a pure value.
instance SingI dType => SGetDataType ('DataType dType) where
  sGetDataType tensor =
    if actual == expected
      then SDataType (sing @dType)
      else
        error $
          "The tensor should have data type "
            <> show expected
            <> " but hasn't. "
            <> gitHubErrorMsg
    where
      expected :: DType
      expected = fromSing (sing @dType)

      actual :: DType
      actual = unsafePerformIO (cast1 ATen.tensor_scalar_type tensor)
-- | Thrown by 'sCheckedDataType' when a tensor's runtime data type differs
-- from the requested one.
data DataTypeError = DataTypeError
  { -- | The data type that was requested.
    dtExpected :: DType,
    -- | The data type the tensor actually has.
    dtActual :: DType
  }
  deriving stock (Show, Typeable)
-- | Human-readable rendering naming both the expected and the actual data type.
instance Exception DataTypeError where
  displayException (DataTypeError expected actual) =
    mconcat
      [ "The tensor does not have the data type `",
        show expected,
        "` but `",
        show actual,
        "`."
      ]
-- | Check at runtime that a tensor has the data type described by a singleton.
--
-- If the tensor's demoted data type equals the requested one, the same tensor
-- is returned, retyped with 'coerce'. Otherwise a 'DataTypeError' is thrown in
-- @m@ with 'throwM'.
--
-- NOTE(review): the @Catch (dataType <+> dataType')@ constraint presumably
-- rejects statically incompatible data types at compile time — confirm.
sCheckedDataType ::
  forall dataType' m gradient layout device dataType shape.
  (SGetDataType dataType, MonadThrow m, Catch (dataType <+> dataType')) =>
  SDataType dataType' ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType' shape)
sCheckedDataType dataType' tensor
  | actualDataType == expectedDataType = pure (coerce tensor)
  | otherwise = throwM (DataTypeError expectedDataType actualDataType)
  where
    actualDataType, expectedDataType :: DType
    actualDataType = forgetIsChecked (fromSing (sGetDataType tensor))
    expectedDataType = forgetIsChecked (fromSing dataType')
-- | Variant of 'sCheckedDataType' that takes the requested data type from a
-- 'SingI' instance (usually chosen with a type application) instead of an
-- explicit singleton argument.
checkedDataType ::
  forall dataType' m gradient layout device dataType shape.
  (SingI dataType', SGetDataType dataType, MonadThrow m, Catch (dataType <+> dataType')) =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType' shape)
checkedDataType tensor = sCheckedDataType (sing @dataType') tensor
-- | Forget the static data type of a tensor.
--
-- This is a type-level change only: the runtime tensor is passed through
-- unchanged via 'coerce'.
uncheckedDataType ::
  forall gradient layout device dataType shape.
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device 'UncheckedDataType shape
uncheckedDataType tensor = coerce tensor
-- | Type-level shapes whose singleton can be recovered from a tensor.
class SGetShape (shape :: Shape [Dim (Name Symbol) (Size Nat)]) where
  -- | Retrieve the shape singleton of a tensor.
  sGetShape ::
    forall gradient layout device dataType.
    Tensor gradient layout device dataType shape ->
    SShape shape

  -- | Retrieve the dimensions of a tensor as plain name/size pairs.
  --
  -- The default demotes the singleton returned by 'sGetShape' and strips the
  -- checked/unchecked wrappers from the list and from each name and size.
  getDims ::
    forall gradient layout device dataType.
    Tensor gradient layout device dataType shape ->
    [Dim String Integer]
  getDims tensor =
    let checkedDims = forgetIsChecked (fromSing (sGetShape tensor))
     in map (bimap forgetIsChecked forgetIsChecked) checkedDims
-- | When the shape is not tracked statically, it is read from the tensor at
-- runtime.
--
-- Sizes come from @ATen.tensor_sizes@. If @ATen.tensor_has_names@ reports
-- named dimensions, each size is paired with its name from
-- @ATen.tensor_names@; otherwise every dimension is given the name @"*"@.
instance SGetShape 'UncheckedShape where
  sGetShape tensor = SUncheckedShape (unsafePerformIO readDims)
    where
      readDims :: IO [Dim String Integer]
      readDims = do
        sizes :: [Integer] <- cast1 ATen.tensor_sizes tensor
        hasNames <- cast1 ATen.tensor_has_names tensor
        if hasNames
          then do
            names :: [String] <- cast1 ATen.tensor_names tensor
            pure (zipWith Dim names sizes)
          else pure (map (Dim "*") sizes)
-- | When the list of dimensions is known statically, the runtime names and
-- sizes are read from the tensor and checked against it by 'sGetDims'.
--
-- Sizes are fetched only when @ATen.tensor_dim@ reports a rank greater than
-- zero; otherwise the size list is empty. Names default to @"*"@ for every
-- dimension when the tensor has no named dimensions.
instance SGetDims dims => SGetShape ('Shape dims) where
  sGetShape tensor = SShape (sGetDims names sizes)
    where
      sizes :: [Integer]
      sizes = unsafePerformIO $ do
        rank :: Int <- cast1 ATen.tensor_dim tensor
        if rank > 0
          then cast1 ATen.tensor_sizes tensor
          else pure []

      names :: [String]
      names = unsafePerformIO $ do
        hasNames <- cast1 ATen.tensor_has_names tensor
        if hasNames
          then cast1 ATen.tensor_names tensor
          else pure (map (const "*") sizes)
-- | Type-level dimension lists that can be rebuilt from runtime dimension
-- names and sizes.
class SGetDims (dims :: [Dim (Name Symbol) (Size Nat)]) where
  -- | Build the dimension-list singleton from parallel lists of runtime names
  -- and sizes. The instances in this module call 'dimsError' when the
  -- lengths of the runtime lists do not line up with @dims@.
  sGetDims :: [String] -> [Integer] -> SList dims
-- | Error raised when the number of compile-time dimensions differs from the
-- number of runtime dimensions.
dimsError :: forall a. a
dimsError =
  error
    ( "The numbers of compile- and runtime dimensions are not the same. "
        ++ gitHubErrorMsg
    )
-- | Error raised when a runtime dimension name differs from the compile-time
-- one.
--
-- The 'SGetDim' instances in this module pass the runtime name first and the
-- compile-time name second.
dimNameError :: forall a. String -> String -> a
dimNameError runtimeName staticName =
  error $
    concat
      [ "The compile- and runtime dimension names are not the same, '",
        runtimeName,
        "' != '",
        staticName,
        "'. ",
        gitHubErrorMsg
      ]
-- | Error raised when a runtime dimension size differs from the compile-time
-- one.
--
-- The 'SGetDim' instances in this module pass the runtime size first and the
-- compile-time size second.
dimSizeError :: forall a b. Show a => a -> a -> b
dimSizeError runtimeSize staticSize =
  error $
    concat
      [ "The compile- and runtime dimension sizes are not the same, '",
        show runtimeSize,
        "' != '",
        show staticSize,
        "'. ",
        gitHubErrorMsg
      ]
-- | Error raised when both the runtime name and the runtime size of a
-- dimension differ from their compile-time counterparts.
--
-- Arguments, in order: runtime name, compile-time name, runtime size,
-- compile-time size (as passed by the 'SGetDim' instance in this module).
dimNameSizeError :: forall a b. Show a => String -> String -> a -> a -> b
dimNameSizeError runtimeName staticName runtimeSize staticSize =
  error $
    concat
      [ "The compile- and runtime dimension names and sizes are not the same, '",
        runtimeName,
        "' != '",
        staticName,
        "' and '",
        show runtimeSize,
        "' != '",
        show staticSize,
        "'. ",
        gitHubErrorMsg
      ]
-- | The empty dimension list matches only when both runtime lists are empty.
instance SGetDims '[] where
  sGetDims names sizes = case (names, sizes) of
    ([], []) -> SNil
    _ -> dimsError
-- | A non-empty dimension list consumes one runtime name and size for its
-- head and recurses on the tails. Running out of runtime names or sizes is
-- reported with 'dimsError'.
instance (SGetDim dim, SGetDims dims) => SGetDims (dim : dims) where
  sGetDims names sizes = case (names, sizes) of
    (name : restNames, size : restSizes) ->
      sGetDim name size :|: sGetDims restNames restSizes
    _ -> dimsError
-- | Type-level dimensions whose singleton can be rebuilt from a runtime name
-- and size.
class SGetDim (dim :: Dim (Name Symbol) (Size Nat)) where
  -- | Build the dimension singleton from the runtime name and size.
  sGetDim :: String -> Integer -> SDim dim
-- | Nothing is known statically: the runtime name and size are stored as-is.
instance SGetDim ('Dim 'UncheckedName 'UncheckedSize) where
  sGetDim runtimeName runtimeSize =
    SDim (SUncheckedName runtimeName) (SUncheckedSize runtimeSize)
-- | Only the name is known statically. It must equal the runtime name,
-- otherwise 'dimNameError' is raised; the runtime size is stored as-is.
instance KnownSymbol name => SGetDim ('Dim ('Name name) 'UncheckedSize) where
  sGetDim runtimeName runtimeSize
    | runtimeName == staticName = SDim (SName @name) (SUncheckedSize runtimeSize)
    | otherwise = dimNameError runtimeName staticName
    where
      staticName = symbolVal (Proxy @name)
-- | Only the size is known statically. It must equal the runtime size,
-- otherwise 'dimSizeError' is raised; the runtime name is stored as-is.
instance KnownNat size => SGetDim ('Dim 'UncheckedName ('Size size)) where
  sGetDim runtimeName runtimeSize
    | runtimeSize == staticSize = SDim (SUncheckedName runtimeName) (SSize @size)
    | otherwise = dimSizeError runtimeSize staticSize
    where
      staticSize = natVal (Proxy @size)
-- | Both name and size are known statically and must match the runtime
-- values.
--
-- Each kind of mismatch gets its own error: 'dimNameError', 'dimSizeError',
-- or 'dimNameSizeError' when both differ.
instance (KnownSymbol name, KnownNat size) => SGetDim ('Dim ('Name name) ('Size size)) where
  sGetDim runtimeName runtimeSize =
    case (runtimeName == staticName, runtimeSize == staticSize) of
      (True, True) -> SDim (SName @name) (SSize @size)
      (False, True) -> dimNameError runtimeName staticName
      (True, False) -> dimSizeError runtimeSize staticSize
      (False, False) -> dimNameSizeError runtimeName staticName runtimeSize staticSize
    where
      staticName = symbolVal (Proxy @name)
      staticSize = natVal (Proxy @size)
-- | Thrown by 'sCheckedShape' when a tensor's runtime dimensions differ from
-- the requested ones.
data ShapeError = ShapeError
  { -- | The dimensions that were requested.
    seExpected :: [Dim String Integer],
    -- | The dimensions the tensor actually has.
    seActual :: [Dim String Integer]
  }
  deriving stock (Show)
-- | Human-readable rendering naming both the expected and the actual shape.
instance Exception ShapeError where
  displayException (ShapeError expected actual) =
    mconcat
      [ "The tensor does not have the shape `",
        show expected,
        "` but `",
        show actual,
        "`."
      ]
-- | Check at runtime that a tensor has the shape described by a singleton.
--
-- Both shapes are demoted to plain @[Dim String Integer]@, with names and
-- sizes unwrapped from 'IsChecked', and compared. On a match the same tensor
-- is returned, retyped with 'coerce'; otherwise a 'ShapeError' is thrown in
-- @m@ with 'throwM'.
--
-- NOTE(review): the @Catch (shape <+> shape')@ constraint presumably rejects
-- statically incompatible shapes at compile time — confirm.
sCheckedShape ::
  forall shape' m gradient layout device dataType shape.
  (SGetShape shape, MonadThrow m, Catch (shape <+> shape')) =>
  SShape shape' ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
sCheckedShape shape' tensor
  | actualShape == expectedShape = pure (coerce tensor)
  | otherwise = throwM (ShapeError expectedShape actualShape)
  where
    -- Polymorphic because it is applied to both @shape@ and @shape'@.
    demote :: forall s. SShape s -> [Dim String Integer]
    demote =
      map (\(Dim name size) -> Dim (forgetIsChecked name) (forgetIsChecked size))
        . forgetIsChecked
        . fromSing

    actualShape, expectedShape :: [Dim String Integer]
    actualShape = demote (sGetShape tensor)
    expectedShape = demote shape'
-- | Variant of 'sCheckedShape' that takes the requested shape from a 'SingI'
-- instance (usually chosen with a type application) instead of an explicit
-- singleton argument.
checkedShape ::
  forall shape' m gradient layout device dataType shape.
  (SingI shape', SGetShape shape, MonadThrow m, Catch (shape <+> shape')) =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
checkedShape tensor = sCheckedShape (sing @shape') tensor
-- | Forget the static information of the dimension selected by @selectDim@.
--
-- At the type level, that dimension is replaced with
-- @'Dim 'UncheckedName 'UncheckedSize@. The runtime tensor is unchanged,
-- since 'coerce' is used.
uncheckedDim ::
  forall selectDim gradient layout device dataType shape.
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType (ReplaceDimF selectDim shape ('Dim 'UncheckedName 'UncheckedSize))
uncheckedDim tensor = coerce tensor
-- | Forget the static shape of a tensor.
--
-- This is a type-level change only: the runtime tensor is passed through
-- unchanged via 'coerce'.
uncheckedShape ::
  forall gradient layout device dataType shape.
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType 'UncheckedShape
uncheckedShape tensor = coerce tensor
-- | Suffix appended to the 'error' messages in this module that report a
-- disagreement between compile-time and runtime information.
gitHubErrorMsg :: String
gitHubErrorMsg =
  "Please open a ticket on GitHub."
-- | Whether the tensor's storage is contiguous, as reported by
-- @ATen.tensor_is_contiguous@.
isContiguous ::
  Tensor gradient layout device dataType shape ->
  Bool
isContiguous = unsafePerformIO . cast1 ATen.tensor_is_contiguous
-- | Return a contiguous version of the tensor via @ATen.tensor_contiguous@.
-- All type-level indices are preserved.
contiguous ::
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType shape
contiguous = unsafePerformIO . cast1 ATen.tensor_contiguous
-- | Run an action on the raw data pointer of a tensor.
--
-- A non-contiguous tensor is first made contiguous with 'contiguous'. The data
-- pointer is obtained with @Unmanaged.tensor_data_ptr@ inside 'withForeignPtr',
-- so the underlying tensor stays alive while the action runs.
withTensor :: Tensor gradient layout device dataType shape -> (Ptr () -> IO a) -> IO a
withTensor t fn =
  cast contiguousTensor $ \foreignTensor ->
    withForeignPtr foreignTensor $ \rawTensor -> do
      dataPtr <- Unmanaged.tensor_data_ptr rawTensor
      fn dataPtr
  where
    contiguousTensor
      | isContiguous t = t
      | otherwise = contiguous t
-- | Low-level marshalling between (possibly nested) Haskell values and the
-- flat element buffer of a tensor.
--
-- The shape of a value is inferred with 'guessDim' and 'guessInnerDims'
-- (combined by 'guessDims'). Elements are then read or written at
-- element-sized offsets from a raw pointer, such as the one 'withTensor'
-- provides.
class TensorLikeRaw a where
  -- | Size of the outermost dimension contributed by @a@ itself. The scalar
  -- instances in this module return 'Nothing'. The argument is an optional
  -- sample value and may be 'Nothing' when only the type matters.
  guessDim :: Maybe a -> Maybe Int

  -- | The dimensions below 'guessDim'. Instances may throw (for example a
  -- 'DimMismatchError') when sub-values disagree on their shape.
  guessInnerDims :: MonadThrow m => Maybe a -> m [Int]

  -- | @tensorPeekElemOff ptr offset dims@ reads a value with dimensions
  -- @dims@ whose first element sits @offset@ elements past @ptr@.
  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO a

  -- | @tensorPokeElemOff ptr offset dims x@ writes @x@, laid out with
  -- dimensions @dims@, starting @offset@ elements past @ptr@.
  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> a -> IO ()
-- | Full list of dimensions for a value of type @a@: the outer dimension
-- from 'guessDim' (if any), followed by the dimensions from 'guessInnerDims'.
-- Any failure raised by 'guessInnerDims' propagates through 'MonadThrow'.
guessDims :: forall a m. (TensorLikeRaw a, MonadThrow m) => Maybe a -> m [Int]
guessDims x = do
  innerDims <- guessInnerDims x
  pure $ case guessDim x of
    Just outer -> outer : innerDims
    Nothing -> innerDims
-- | Report that a value of type @a@ was peeked or poked with dimensions
-- @dims'@ that differ from what 'guessDims' infers for it.
--
-- The expected dimensions are computed in @m@, so a failure inside
-- 'guessDims' goes through 'MonadThrow'. The final report, however, is made
-- with 'error' (an @ErrorCall@), not 'throwM'.
-- NOTE(review): in a pure 'MonadThrow' such as 'Maybe' the whole result is
-- therefore bottom rather than 'Nothing'. All call sites visible here run in
-- 'IO'.
unexpectedDimsError :: forall a m b. (TensorLikeRaw a, MonadThrow m) => [Int] -> Maybe a -> m b
unexpectedDimsError dims' x =
  guessDims x >>= \expected ->
    error ("Expected shape to be " ++ show expected ++ " got: " ++ show dims')
-- | Haskell types with a statically known tensor representation.
--
-- The functional dependencies @a -> dims@ and @a -> dType@ mean the Haskell
-- type alone determines the element data type and the shape of the tensor.
class TensorLike a (dType :: DType) (dims :: [Dim (Name Symbol) (Size Nat)]) | a -> dims, a -> dType where
  -- | Convert a value into a tensor. Gradient, layout and device are given
  -- as explicit singletons; failures are reported via 'MonadThrow'.
  sToTensor ::
    forall gradient layout device m.
    MonadThrow m =>
    SGradient gradient ->
    SLayout layout ->
    SDevice device ->
    a ->
    m (Tensor gradient layout device ('DataType dType) ('Shape dims))

  -- | Convert a tensor of the matching data type and shape back into a value.
  fromTensor ::
    forall gradient layout device.
    Tensor gradient layout device ('DataType dType) ('Shape dims) ->
    a
-- | Like 'sToTensor', but the gradient, layout and device singletons are
-- recovered from the 'SingI' instances of the requested result type.
toTensor ::
  forall gradient layout device a dType dims m.
  ( TensorLike a dType dims,
    SingI gradient,
    SingI layout,
    SingI device,
    MonadThrow m
  ) =>
  a ->
  m (Tensor gradient layout device ('DataType dType) ('Shape dims))
toTensor x = sToTensor gradientS layoutS deviceS x
  where
    gradientS = sing @gradient
    layoutS = sing @layout
    deviceS = sing @device
-- | Generic implementation of 'sToTensor' for types with a 'TensorLikeRaw'
-- instance. The scalar and tuple 'TensorLike' instances delegate to it.
--
-- 1. The dimensions are inferred with 'guessDims' in @m@, so shape
--    inconsistencies are reported via 'MonadThrow'.
-- 2. A CPU tensor with those dimensions is allocated with @ATen.empty_lo@
--    and filled from @x@ via 'tensorPokeElemOff' (inside 'unsafePerformIO').
-- 3. The tensor is moved to @device@ with 'sSetDevice', which also runs in
--    @m@. A failing transfer therefore surfaces through 'MonadThrow' rather
--    than as an exception raised whenever a lazily returned tensor is forced.
sToTensorRaw ::
  forall gradient layout device a dType dims m.
  (TensorLike a dType dims, TensorLikeRaw a, SingI dType, MonadThrow m) =>
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  a ->
  m (Tensor gradient layout device ('DataType dType) ('Shape dims))
sToTensorRaw gradient' layout device x = do
  dims' <- guessDims $ pure x
  sSetDevice device (allocateOnCpu dims')
  where
    -- Options for the staging tensor: the requested gradient and layout, but
    -- always the CPU device (presumably so the buffer can be written through
    -- a host pointer by 'tensorPokeElemOff').
    opts = tensorOptions gradient' layout (SDevice SCPU) (sing @('DataType dType))

    -- Allocate a CPU tensor with the given dimensions and fill it from @x@.
    allocateOnCpu sizes = unsafePerformIO $ do
      t <- UnsafeTensor <$> cast2 ATen.empty_lo sizes opts
      withTensor t $ \ptr -> tensorPokeElemOff ptr 0 sizes x
      pure t
-- | Generic implementation of 'fromTensor' for types with a 'TensorLikeRaw'
-- instance.
--
-- The tensor is first moved to the CPU with 'sSetDevice'. The value is then
-- read from offset 0 of its element buffer, using the dimension sizes that
-- 'getDims' reports. The I/O is hidden behind 'unsafePerformIO', so errors
-- from the transfer or from 'tensorPeekElemOff' are raised when the result
-- is forced.
fromTensorRaw ::
  forall gradient layout device a dType dims.
  (TensorLike a dType dims, TensorLikeRaw a, SGetDims dims) =>
  Tensor gradient layout device ('DataType dType) ('Shape dims) ->
  a
fromTensorRaw t = unsafePerformIO $ do
  cpuTensor <- sSetDevice (SDevice SCPU) t
  withTensor cpuTensor $ \ptr -> tensorPeekElemOff ptr 0 sizes
  where
    sizes :: [Int]
    sizes = fromInteger . dimSize <$> getDims t
-- | A 'Bool' maps to a scalar (rank-0) tensor of data type @'Bool@. Both
-- directions delegate to the generic raw helpers.
instance TensorLike Bool 'Bool '[] where
  fromTensor ::
    forall gradient layout device.
    Tensor gradient layout device ('DataType 'Bool) ('Shape '[]) ->
    Bool
  fromTensor = fromTensorRaw

  sToTensor ::
    forall gradient layout device m.
    MonadThrow m =>
    SGradient gradient ->
    SLayout layout ->
    SDevice device ->
    Bool ->
    m (Tensor gradient layout device ('DataType 'Bool) ('Shape '[]))
  sToTensor = sToTensorRaw
-- | A 'Bool' is a scalar stored as one byte ('Word8') per element. Any
-- non-empty dimension list is rejected via 'unexpectedDimsError'.
instance TensorLikeRaw Bool where
  guessDim :: Maybe Bool -> Maybe Int
  guessDim = const empty

  guessInnerDims :: MonadThrow m => Maybe Bool -> m [Int]
  guessInnerDims = const $ pure mempty

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO Bool
  -- Any non-zero byte decodes as 'True'. This is the @toBool@ inverse of the
  -- 'fromBool' used when poking, and it no longer maps bytes other than
  -- 0 and 1 to 'False'.
  tensorPeekElemOff ptr offset [] = peekElemOff @Word8 (castPtr ptr) offset <&> (/= 0)
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @Bool dims' empty

  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> Bool -> IO ()
  tensorPokeElemOff ptr offset [] x = pokeElemOff @Word8 (castPtr ptr) offset (fromBool x)
  tensorPokeElemOff _ _ dims' x = unexpectedDimsError dims' $ pure x
-- | An 'Int' maps to a scalar (rank-0) tensor of data type @'Int64@. Both
-- directions delegate to the generic raw helpers.
instance TensorLike Int 'Int64 '[] where
  fromTensor ::
    forall gradient layout device.
    Tensor gradient layout device ('DataType 'Int64) ('Shape '[]) ->
    Int
  fromTensor = fromTensorRaw

  sToTensor ::
    forall gradient layout device m.
    MonadThrow m =>
    SGradient gradient ->
    SLayout layout ->
    SDevice device ->
    Int ->
    m (Tensor gradient layout device ('DataType 'Int64) ('Shape '[]))
  sToTensor = sToTensorRaw
-- | An 'Int' is a scalar read and written with its native 'Storable' layout.
-- NOTE(review): the matching 'TensorLike' instance uses @'Int64@, so this
-- relies on 'Int' being 64 bits wide on the target platform.
instance TensorLikeRaw Int where
  guessDim :: Maybe Int -> Maybe Int
  guessDim _ = Nothing

  guessInnerDims :: MonadThrow m => Maybe Int -> m [Int]
  guessInnerDims _ = pure []

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO Int
  tensorPeekElemOff ptr offset dims' = case dims' of
    [] -> peekElemOff (castPtr ptr) offset
    _ -> unexpectedDimsError @Int dims' Nothing

  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> Int -> IO ()
  tensorPokeElemOff ptr offset dims' x = case dims' of
    [] -> pokeElemOff (castPtr ptr) offset x
    _ -> unexpectedDimsError dims' (Just x)
-- | A 'Float' maps to a scalar (rank-0) tensor of data type @'Float@. Both
-- directions delegate to the generic raw helpers.
instance TensorLike Float 'Float '[] where
  fromTensor ::
    forall gradient layout device.
    Tensor gradient layout device ('DataType 'Float) ('Shape '[]) ->
    Float
  fromTensor = fromTensorRaw

  sToTensor ::
    forall gradient layout device m.
    MonadThrow m =>
    SGradient gradient ->
    SLayout layout ->
    SDevice device ->
    Float ->
    m (Tensor gradient layout device ('DataType 'Float) ('Shape '[]))
  sToTensor = sToTensorRaw
-- | A 'Float' is a scalar read and written with its native 'Storable'
-- layout. Any non-empty dimension list is rejected.
instance TensorLikeRaw Float where
  guessDim :: Maybe Float -> Maybe Int
  guessDim _ = Nothing

  guessInnerDims :: MonadThrow m => Maybe Float -> m [Int]
  guessInnerDims _ = pure []

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO Float
  tensorPeekElemOff ptr offset dims' = case dims' of
    [] -> peekElemOff (castPtr ptr) offset
    _ -> unexpectedDimsError @Float dims' Nothing

  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> Float -> IO ()
  tensorPokeElemOff ptr offset dims' x = case dims' of
    [] -> pokeElemOff (castPtr ptr) offset x
    _ -> unexpectedDimsError dims' (Just x)
-- | A 'Double' maps to a scalar (rank-0) tensor of data type @'Double@. Both
-- directions delegate to the generic raw helpers.
instance TensorLike Double 'Double '[] where
  fromTensor ::
    forall gradient layout device.
    Tensor gradient layout device ('DataType 'Double) ('Shape '[]) ->
    Double
  fromTensor = fromTensorRaw

  sToTensor ::
    forall gradient layout device m.
    MonadThrow m =>
    SGradient gradient ->
    SLayout layout ->
    SDevice device ->
    Double ->
    m (Tensor gradient layout device ('DataType 'Double) ('Shape '[]))
  sToTensor = sToTensorRaw
-- | A 'Double' is a scalar read and written with its native 'Storable'
-- layout. Any non-empty dimension list is rejected.
instance TensorLikeRaw Double where
  guessDim :: Maybe Double -> Maybe Int
  guessDim _ = Nothing

  guessInnerDims :: MonadThrow m => Maybe Double -> m [Int]
  guessInnerDims _ = pure []

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO Double
  tensorPeekElemOff ptr offset dims' = case dims' of
    [] -> peekElemOff (castPtr ptr) offset
    _ -> unexpectedDimsError @Double dims' Nothing

  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> Double -> IO ()
  tensorPokeElemOff ptr offset dims' x = case dims' of
    [] -> pokeElemOff (castPtr ptr) offset x
    _ -> unexpectedDimsError dims' (Just x)
-- | Raised when sibling elements of a nested value (for example the
-- components of a tuple) have different shapes and so cannot be stacked into
-- one tensor.
data DimMismatchError = DimMismatchError
  { -- | Shape of the first element, used as the reference.
    dmeFirst :: [Int],
    -- | Shape of an element that disagrees with 'dmeFirst'.
    dmeOther :: [Int]
  }
  deriving (Eq, Show)
-- | Human-readable description naming both conflicting shapes.
instance Exception DimMismatchError where
  displayException :: DimMismatchError -> String
  displayException err =
    concat
      [ "When converting to a tensor, all elements on the same dimension must have the same shape, ",
        "but the first element has shape ",
        show (dmeFirst err),
        " while another element has shape ",
        show (dmeOther err),
        "."
      ]
-- | Succeed when both shapes are equal. Otherwise throw a 'DimMismatchError'
-- with the first argument as 'dmeFirst' and the second as 'dmeOther'.
checkDims :: MonadThrow m => [Int] -> [Int] -> m ()
checkDims firstDims otherDims
  | firstDims == otherDims = pure ()
  | otherwise = throwM (DimMismatchError {dmeFirst = firstDims, dmeOther = otherDims})
-- | A pair of tensor-like values with the same data type becomes a tensor
-- with a new leading dimension of size 2. The element shapes @dims@ and
-- @dims'@ are combined with '<+>' before that dimension is inserted at
-- index 0.
instance
  ( TensorLike a dType dims,
    TensorLike b dType dims',
    TensorLikeRaw a,
    TensorLikeRaw b,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape (dims <+> dims')) ('Dim ('Name "*") ('Size 2))
  ) =>
  TensorLike (a, b) dType dimsOut
  where
  fromTensor ::
    forall gradient layout device.
    Tensor gradient layout device ('DataType dType) ('Shape dimsOut) ->
    (a, b)
  fromTensor = fromTensorRaw

  sToTensor ::
    forall gradient layout device m.
    MonadThrow m =>
    SGradient gradient ->
    SLayout layout ->
    SDevice device ->
    (a, b) ->
    m (Tensor gradient layout device ('DataType dType) ('Shape dimsOut))
  sToTensor = sToTensorRaw
-- | A pair adds an outer dimension of size 2. Both components must have the
-- same inner shape, and they are stored one after the other, each occupying
-- @product innerDims@ elements.
instance (TensorLikeRaw a, TensorLikeRaw b) => TensorLikeRaw (a, b) where
  guessDim :: Maybe (a, b) -> Maybe Int
  guessDim _ = Just 2

  guessInnerDims :: MonadThrow m => Maybe (a, b) -> m [Int]
  guessInnerDims pair = do
    let (firstPart, secondPart) = unzip pair
    firstDims <- guessDims firstPart
    secondDims <- guessDims secondPart
    checkDims firstDims secondDims
    pure firstDims

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO (a, b)
  tensorPeekElemOff ptr offset (2 : innerDims) = do
    let stride = product innerDims
    x <- tensorPeekElemOff ptr offset innerDims
    y <- tensorPeekElemOff ptr (offset + stride) innerDims
    pure (x, y)
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @(a, b) dims' Nothing

  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> (a, b) -> IO ()
  tensorPokeElemOff ptr offset (2 : innerDims) (x, y) = do
    let stride = product innerDims
    tensorPokeElemOff ptr offset innerDims x
    tensorPokeElemOff ptr (offset + stride) innerDims y
  tensorPokeElemOff _ _ dims' pair = unexpectedDimsError dims' (Just pair)
-- | A triple of tensor-like values with the same data type becomes a tensor
-- with a new leading dimension of size 3.
--
-- NOTE(review): @c@ is constrained to the same @dims'@ as @b@, and only
-- @dims <+> dims'@ enters the result shape. Confirm this is intended.
instance
  ( TensorLike a dType dims,
    TensorLike b dType dims',
    TensorLike c dType dims',
    TensorLikeRaw a,
    TensorLikeRaw b,
    TensorLikeRaw c,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape (dims <+> dims')) ('Dim ('Name "*") ('Size 3))
  ) =>
  TensorLike (a, b, c) dType dimsOut
  where
  fromTensor ::
    forall gradient layout device.
    Tensor gradient layout device ('DataType dType) ('Shape dimsOut) ->
    (a, b, c)
  fromTensor = fromTensorRaw

  sToTensor ::
    forall gradient layout device m.
    MonadThrow m =>
    SGradient gradient ->
    SLayout layout ->
    SDevice device ->
    (a, b, c) ->
    m (Tensor gradient layout device ('DataType dType) ('Shape dimsOut))
  sToTensor = sToTensorRaw
-- | Split a functor of triples into three functors, one per component.
-- Works for any 'Functor' (for example 'Maybe' or lists), not just lists.
unzip3 :: Functor f => f (a, b, c) -> (f a, f b, f c)
unzip3 triples = (fmap pick1 triples, fmap pick2 triples, fmap pick3 triples)
  where
    pick1 (x, _, _) = x
    pick2 (_, y, _) = y
    pick3 (_, _, z) = z
instance (TensorLikeRaw a, TensorLikeRaw b, TensorLikeRaw c) => TensorLikeRaw (a, b, c) where
-- A triple contributes an outer dimension of size 3, one slot per component.
-- This must agree with the @3 : innerDims@ pattern that 'tensorPeekElemOff'
-- matches for triples and with the @'Size 3@ in the triple 'TensorLike'
-- instance. Otherwise 'sToTensorRaw' allocates the wrong shape and the
-- element dimension check rejects it.
guessDim :: Maybe (a, b, c) -> Maybe Int
guessDim = const $ pure 3
guessInnerDims :: forall (m :: * -> *). MonadThrow m => Maybe (a, b, c) -> m [Int]
guessInnerDims (forall (f :: * -> *) a b c.
Functor f =>
f (a, b, c) -> (f a, f b, f c)
unzip3 -> (Maybe a
x, Maybe b
y, Maybe c
z)) = do
[Int]
xDims <- forall a (m :: * -> *).
(TensorLikeRaw a, MonadThrow m) =>
Maybe a -> m [Int]
guessDims Maybe a
x
[Int]
yDims <- forall a (m :: * -> *).
(TensorLikeRaw a, MonadThrow m) =>
Maybe a -> m [Int]
guessDims Maybe b
y
[Int]
zDims <- forall a (m :: * -> *).
(TensorLikeRaw a, MonadThrow m) =>
Maybe a -> m [Int]
guessDims Maybe c
z
forall (t :: * -> *) (f :: * -> *) a b.
(Foldable t, Applicative f) =>
(a -> f b) -> t a -> f ()
traverse_ (forall (m :: * -> *). MonadThrow m => [Int] -> [Int] -> m ()
checkDims [Int]
xDims) [[Int]
yDims, [Int]
zDims]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Int]
xDims
tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO (a, b, c)
tensorPeekElemOff Ptr ()
ptr Int
offset (Int
3 : [Int]
innerDims) =
(,,)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> IO a
tensorPeekElemOff Ptr ()
ptr Int
offset [Int]
innerDims
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> IO a
tensorPeekElemOff Ptr ()
ptr (Int
offset forall a. Num a => a -> a -> a
+ Int
width) [Int]
innerDims
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> IO a
tensorPeekElemOff Ptr ()
ptr (Int
offset forall a. Num a => a -> a -> a
+ Int
2 forall a. Num a => a -> a -> a
* Int
width) [Int]
innerDims
where
width :: Int
width = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Int]
innerDims
tensorPeekElemOff Ptr ()
_ Int
_ [Int]
dims' = forall a (m :: * -> *) b.
(TensorLikeRaw a, MonadThrow m) =>
[Int] -> Maybe a -> m b
unexpectedDimsError @(a, b) [Int]
dims' forall (f :: * -> *) a. Alternative f => f a
empty
tensorPokeElemOff :: Ptr () -> Int -> [Int] -> (a, b, c) -> IO ()
tensorPokeElemOff Ptr ()
ptr Int
offset (Int
3 : [Int]
innerDims) (a
x, b
y, c
z) = do
forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> a -> IO ()
tensorPokeElemOff Ptr ()
ptr Int
offset [Int]
innerDims a
x
forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> a -> IO ()
tensorPokeElemOff Ptr ()
ptr (Int
offset forall a. Num a => a -> a -> a
+ Int
width) [Int]
innerDims b
y
forall a. TensorLikeRaw a => Ptr () -> Int -> [Int] -> a -> IO ()
tensorPokeElemOff Ptr ()
ptr (Int
offset forall a. Num a => a -> a -> a
+ Int
2 forall a. Num a => a -> a -> a
* Int
width) [Int]
innerDims c
z
where
width :: Int
width = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Int]
innerDims
tensorPokeElemOff Ptr ()
_ Int
_ [Int]
dims' (a, b, c)
x = forall a (m :: * -> *) b.
(TensorLikeRaw a, MonadThrow m) =>
[Int] -> Maybe a -> m b
unexpectedDimsError [Int]
dims' forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure (a, b, c)
x
-- | Lists become tensors with an extra leading dimension (named @"*"@, size
-- unchecked) inserted at index 0; the remaining dimensions come from the
-- element type @a@.
instance
  ( TensorLike a dType dims,
    TensorLikeRaw a,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape dims) ('Dim ('Name "*") 'UncheckedSize)
  ) =>
  TensorLike [a] dType dimsOut
  where
  -- Both directions are delegated to the shared raw helpers.
  fromTensor tensor = fromTensorRaw tensor

  sToTensor gradient layout device = sToTensorRaw gradient layout device
-- | Raw tensor conversion for lists.
--
-- The list length is the leading dimension. Element @i@ occupies the slice
-- starting at @offset + i * product innerDims@.
instance TensorLikeRaw a => TensorLikeRaw [a] where
  -- A missing list is treated as having length 0.
  guessDim :: Maybe [a] -> Maybe Int
  guessDim mxs = Just (maybe 0 length mxs)

  -- Inner dimensions are taken from the first element, and every later
  -- element is checked against them. For an absent or empty list, the
  -- element type's 'guessDims' is consulted with no value.
  guessInnerDims :: forall m. MonadThrow m => Maybe [a] -> m [Int]
  guessInnerDims mxs = case mxs >>= nonEmpty of
    Nothing -> guessDims @a empty
    Just (x :| rest) -> do
      firstDims <- guessDims (Just x)
      forM_ rest $ \x' -> guessDims (Just x') >>= checkDims firstDims
      pure firstDims

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO [a]
  tensorPeekElemOff ptr offset (d : innerDims) =
    traverse readSlice [0 .. d - 1]
    where
      width :: Int
      width = product innerDims
      readSlice i = tensorPeekElemOff ptr (offset + i * width) innerDims
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @[a] dims' empty

  -- Pairs indices with elements, so at most @d@ elements are written.
  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> [a] -> IO ()
  tensorPokeElemOff ptr offset (d : innerDims) xs =
    traverse_ writeSlice (zip [0 .. d - 1] xs)
    where
      width :: Int
      width = product innerDims
      writeSlice (i, x) = tensorPokeElemOff ptr (offset + i * width) innerDims x
  tensorPokeElemOff _ _ dims' xs = unexpectedDimsError dims' $ pure xs
-- | Boxed vectors become tensors with an extra leading dimension (named
-- @"*"@, size unchecked) inserted at index 0; the remaining dimensions come
-- from the element type @a@.
instance
  ( TensorLike a dType dims,
    TensorLikeRaw a,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape dims) ('Dim ('Name "*") 'UncheckedSize)
  ) =>
  TensorLike (V.Vector a) dType dimsOut
  where
  -- Both directions are delegated to the shared raw helpers.
  fromTensor tensor = fromTensorRaw tensor

  sToTensor gradient layout device = sToTensorRaw gradient layout device
-- | Raw tensor conversion for boxed vectors.
--
-- The vector length is the leading dimension. Element @i@ occupies the
-- slice starting at @offset + i * product innerDims@.
instance
  TensorLikeRaw a =>
  TensorLikeRaw (V.Vector a)
  where
  -- A missing vector is treated as having length 0.
  guessDim :: Maybe (V.Vector a) -> Maybe Int
  guessDim mxs = Just (maybe 0 length mxs)

  -- Inner dimensions are taken from the first element, and every later
  -- element is checked against them. For an absent or empty vector, the
  -- element type's 'guessDims' is consulted with no value.
  guessInnerDims :: forall m. MonadThrow m => Maybe (V.Vector a) -> m [Int]
  guessInnerDims mxs = case mxs >>= V.uncons of
    Nothing -> guessDims @a empty
    Just (x, rest) -> do
      firstDims <- guessDims (Just x)
      forM_ rest $ \x' -> guessDims (Just x') >>= checkDims firstDims
      pure firstDims

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO (V.Vector a)
  tensorPeekElemOff ptr offset (d : innerDims) =
    mapM readSlice (V.enumFromTo 0 (d - 1))
    where
      width :: Int
      width = product innerDims
      readSlice i = tensorPeekElemOff ptr (offset + i * width) innerDims
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @(V.Vector a) dims' empty

  -- Pairs indices with elements, so at most @d@ elements are written.
  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> V.Vector a -> IO ()
  tensorPokeElemOff ptr offset (d : innerDims) xs =
    traverse_ writeSlice (V.zip (V.enumFromTo 0 (d - 1)) xs)
    where
      width :: Int
      width = product innerDims
      writeSlice (i, x) = tensorPokeElemOff ptr (offset + i * width) innerDims x
  tensorPokeElemOff _ _ dims' xs = unexpectedDimsError dims' $ pure xs
-- | Length-indexed vectors become tensors with an extra leading dimension
-- (named @"*"@, statically sized @n@) inserted at index 0; the remaining
-- dimensions come from the element type @a@.
instance
  ( KnownNat n,
    TensorLike a dType dims,
    TensorLikeRaw a,
    SingI dType,
    SGetDims dimsOut,
    'Shape dimsOut ~ InsertDimF ('SelectDim ('ByIndex 0)) ('Shape dims) ('Dim ('Name "*") ('Size n))
  ) =>
  TensorLike (SV.Vector n a) dType dimsOut
  where
  -- Both directions are delegated to the shared raw helpers.
  fromTensor tensor = fromTensorRaw tensor

  sToTensor gradient layout device = sToTensorRaw gradient layout device
-- | Raw tensor conversion for length-indexed vectors.
--
-- Everything is delegated to the unsized 'V.Vector' instance. In addition,
-- the leading runtime dimension is checked to equal the type-level length
-- @n@ before peeking or poking. Peeking otherwise wraps the unsized result
-- with the raw 'SVI.Vector' constructor, which does not validate the length
-- and could yield a sized vector violating its own invariant.
instance
  ( KnownNat n,
    TensorLikeRaw a
  ) =>
  TensorLikeRaw (SV.Vector n a)
  where
  guessDim :: Maybe (SV.Vector n a) -> Maybe Int
  guessDim = pure . maybe 0 length

  guessInnerDims :: forall m. MonadThrow m => Maybe (SV.Vector n a) -> m [Int]
  guessInnerDims = guessInnerDims . fmap SV.SomeSized

  tensorPeekElemOff :: Ptr () -> Int -> [Int] -> IO (SV.Vector n a)
  tensorPeekElemOff ptr offset dims'@(d : _)
    -- Only wrap unchecked when the runtime length matches @n@.
    | toInteger d == natVal (Proxy :: Proxy n) =
        SVI.Vector <$> tensorPeekElemOff ptr offset dims'
  tensorPeekElemOff _ _ dims' = unexpectedDimsError @(SV.Vector n a) dims' empty

  tensorPokeElemOff :: Ptr () -> Int -> [Int] -> SV.Vector n a -> IO ()
  tensorPokeElemOff ptr offset dims'@(d : _) xs
    -- A mismatched leading dimension would otherwise write only part of
    -- the vector (or leave slices unwritten) without any error.
    | toInteger d == natVal (Proxy :: Proxy n) =
        tensorPokeElemOff ptr offset dims' (SV.SomeSized xs)
  tensorPokeElemOff _ _ dims' xs = unexpectedDimsError dims' $ pure xs
-- | Convert a tensor to the given gradient, layout, device and data type,
-- keeping its shape index.
--
-- The options are built with 'tensorOptions' from the four singletons and
-- passed to @ATen.tensor_to_obb@ (via 'cast4'). Both trailing flags
-- (@nonBlocking@ and @copy@) are 'False'. The resulting pointer is wrapped
-- in 'UnsafeTensor'.
sSetTensorOptions ::
  forall gradient layout device dataType gradientFrom layoutFrom deviceFrom dataTypeFrom shape.
  SGradient gradient ->
  SLayout layout ->
  SDevice device ->
  SDataType dataType ->
  Tensor gradientFrom layoutFrom deviceFrom dataTypeFrom shape ->
  IO (Tensor gradient layout device dataType shape)
sSetTensorOptions gradient' layout' device' dataType' tensor =
  let options = tensorOptions gradient' layout' device' dataType'
      nonBlocking = False
      copy = False
   in UnsafeTensor <$> cast4 ATen.tensor_to_obb tensor options nonBlocking copy
-- | Type-driven variant of 'sSetTensorOptions': the target gradient, layout,
-- device and data type are read from the type level via 'SingI' instead of
-- being passed as singleton arguments.
setTensorOptions ::
  forall gradient layout device dataType gradientFrom layoutFrom deviceFrom dataTypeFrom shape.
  ( SingI gradient,
    SingI layout,
    SingI device,
    SingI dataType
  ) =>
  Tensor gradientFrom layoutFrom deviceFrom dataTypeFrom shape ->
  IO (Tensor gradient layout device dataType shape)
setTensorOptions tensor =
  sSetTensorOptions gradient' layout' device' dataType' tensor
  where
    gradient' = sing @gradient
    layout' = sing @layout
    device' = sing @device
    dataType' = sing @dataType