{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
module Torch.Typed.Tensor where
import Control.Arrow
import Control.Category
import Data.Finite
import Data.Kind
( Constraint,
Type,
)
import Data.Maybe
import Data.Proxy
import Data.Reflection
import Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as V
import Foreign.ForeignPtr
import Foreign.Storable
import GHC.Exts
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D hiding (select)
import Torch.HList
import Torch.Internal.Cast
import Torch.Internal.Class
( Castable (..),
CppTuple2 (..),
CppTuple3 (..),
CppTuple4 (..),
)
import qualified Torch.Internal.Type as ATen
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.Typed.Auxiliary
import Prelude hiding (id, (.))
-- | Reflects a type-level list of dimension sizes to a runtime list of 'Int's,
-- in the same order as the type-level list.
class KnownShape (shape :: [Nat]) where
  shapeVal :: [Int]

-- | The empty shape reflects to the empty list.
instance KnownShape '[] where
  shapeVal = []

-- | A non-empty shape reflects its head with 'natValI' and recurses on the tail.
instance (KnownNat h, KnownShape t) => KnownShape (h ': t) where
  shapeVal = natValI @h : shapeVal @t
-- | Converts a bounded 'Finite' index to an 'Int'.
--
-- NOTE(review): 'fromIntegral' from 'Integer' wraps silently if the index
-- exceeds @maxBound :: Int@; this only matters for bounds larger than 'Int'.
getFiniteI :: Finite n -> Int
getFiniteI = fromIntegral . getFinite
-- | Reflects a promoted 'D.DType' to its runtime value.
class KnownDType (dtype :: D.DType) where
  dtypeVal :: D.DType

instance KnownDType 'D.Bool where
  dtypeVal = D.Bool

instance KnownDType 'D.UInt8 where
  dtypeVal = D.UInt8

instance KnownDType 'D.Int8 where
  dtypeVal = D.Int8

instance KnownDType 'D.Int16 where
  dtypeVal = D.Int16

instance KnownDType 'D.Int32 where
  dtypeVal = D.Int32

instance KnownDType 'D.Int64 where
  dtypeVal = D.Int64

instance KnownDType 'D.Half where
  dtypeVal = D.Half

instance KnownDType 'D.Float where
  dtypeVal = D.Float

instance KnownDType 'D.Double where
  dtypeVal = D.Double
-- | Computes the libtorch 'D.DType' for either a Haskell element type
-- ('Bool', 'Int', 'Float', 'Double') or an already-promoted 'D.DType'.
--
-- 'Int' maps to 'D.Int64'. Every promoted 'D.DType' that has a 'KnownDType'
-- instance, including 'D.Half', maps to itself. Anything else is rejected
-- with a custom type error.
type family ComputeDType (dtype' :: dtype) :: D.DType where
  ComputeDType Bool = D.Bool
  ComputeDType D.Bool = D.Bool
  ComputeDType D.UInt8 = D.UInt8
  ComputeDType D.Int8 = D.Int8
  ComputeDType D.Int16 = D.Int16
  ComputeDType D.Int32 = D.Int32
  ComputeDType Int = D.Int64
  ComputeDType D.Int64 = D.Int64
  -- Ticked so it unambiguously names the promoted constructor.
  ComputeDType 'D.Half = 'D.Half
  ComputeDType Float = D.Float
  ComputeDType D.Float = D.Float
  ComputeDType Double = D.Double
  ComputeDType D.Double = D.Double
  ComputeDType dtype' = TypeError (Text "Unsupported tensor type " :<>: ShowType dtype')
-- | Reflects a promoted @(deviceType, deviceIndex)@ pair to a runtime 'D.Device'.
class KnownDevice (device :: (D.DeviceType, Nat)) where
  deviceVal :: D.Device

-- | CPU device with index @n@.
instance (KnownNat n) => KnownDevice '( 'D.CPU, n) where
  deviceVal = D.Device D.CPU (natValInt16 @n)

-- | CUDA device with index @n@.
instance (KnownNat n) => KnownDevice '( 'D.CUDA, n) where
  deviceVal = D.Device D.CUDA (natValInt16 @n)
-- | Kind of a single-dimension descriptor: a unary type constructor such as
-- @'Vector' n@ (see 'FromNat'), or one whose generic representation
-- 'ToNat' can measure.
type Size = Type -> Type

-- | Kind of a shape descriptor: one 'Size' per dimension.
type Shape = [Size]
-- | Computes the size of one dimension from its 'Size' descriptor.
--
-- * @'Vector' n@ has size @n@.
-- * Any other descriptor @a@ is measured through the "GHC.Generics"
--   representation of @a ()@:
--
--     * fields of a product are summed;
--     * the alternatives of a sum take the larger size;
--     * a field of type @'Vector' n x@ contributes @n@;
--     * any other field contributes 1;
--     * a nullary constructor ('U1') counts as 1.
type family ToNat (shape :: Size) :: Nat where
  ToNat (S1 ('MetaSel _ _ _ _) f) = ToNat f
  ToNat (D1 _ f) = ToNat f
  ToNat (C1 _ f) = ToNat f
  ToNat (l :*: r) = ToNat l + ToNat r
  ToNat (l :+: r) = If (ToNat l <=? ToNat r) (ToNat r) (ToNat l)
  ToNat (K1 R (Vector n _)) = n
  ToNat (K1 _ _) = 1
  ToNat U1 = 1
  ToNat (Vector n) = n
  ToNat a = ToNat (Rep (a ()))
-- | Converts a list of 'Size' descriptors to type-level dimension sizes,
-- element by element, via 'ToNat'.
type family ToNats (shape :: Shape) :: [Nat] where
  ToNats '[] = '[]
  ToNats (x ': xs) = ToNat x ': ToNats xs

-- | The canonical 'Size' descriptor for a dimension of size @n@.
type family FromNat (shape :: Nat) :: Size where
  FromNat n = Vector n

-- | Converts type-level dimension sizes to 'Size' descriptors, element by
-- element, via 'FromNat'.
type family FromNats (shape :: [Nat]) :: Shape where
  FromNats '[] = '[]
  FromNats (x ': xs) = FromNat x ': FromNats xs
-- | Types that can be converted to and from a typed 'Tensor'.
--
-- The associated families give the type-level shape, device and dtype of
-- that tensor.
class Unnamed t where
  type UTShape t :: [Nat]
  type UTDevice t :: (D.DeviceType, Nat)
  type UTDType t :: D.DType

  -- | View @t@ as a typed 'Tensor' whose indices match the associated families.
  toUnnamed ::
    forall device dtype shape.
    IsUnnamed t device dtype shape =>
    t ->
    Tensor device dtype shape

  -- | Rebuild @t@ from a typed 'Tensor' whose indices match the associated families.
  fromUnnamed ::
    forall device dtype shape.
    IsUnnamed t device dtype shape =>
    Tensor device dtype shape ->
    t

  -- | Return the underlying untyped 'D.Tensor'.
  toDynamic ::
    t -> D.Tensor
-- | Constraint bundle: @t@ is 'Unnamed', and its associated device, dtype and
-- shape are exactly @device@, @dtype@ and @shape@.
type family IsUnnamed t (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: [Nat]) :: Constraint where
  IsUnnamed t device dtype shape =
    ( Unnamed t,
      device ~ (UTDevice t),
      dtype ~ (UTDType t),
      shape ~ (UTShape t)
    )
-- | A typed 'Tensor' is trivially 'Unnamed': both conversions are the
-- identity, and 'toDynamic' unwraps the untyped tensor.
instance Unnamed (Tensor device dtype shape) where
  type UTShape (Tensor device dtype shape) = shape
  type UTDevice (Tensor device dtype shape) = device
  type UTDType (Tensor device dtype shape) = dtype
  toUnnamed = id
  fromUnnamed = id
  toDynamic (UnsafeMkTensor t) = t
-- | A tensor whose device, dtype and shape are tracked at the type level.
--
-- 'UnsafeMkTensor' wraps an untyped 'D.Tensor' without checking that its
-- runtime device, dtype or shape agree with the type parameters. Callers are
-- responsible for that invariant.
data Tensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: [Nat]) where
  UnsafeMkTensor :: forall device dtype shape. D.Tensor -> Tensor device dtype shape

-- | A typed tensor on CPU device 0.
type CPUTensor = Tensor '( 'D.CPU, 0)

-- | A typed tensor on the CUDA device with the given index.
type CUDATensor deviceIndex = Tensor '( 'D.CUDA, deviceIndex)

-- | A typed tensor whose shape is hidden existentially.
data UnknownShapeTensor device dtype = forall shape. UnknownShapeTensor (Tensor device dtype shape)
-- | The Haskell element type used to read and write tensors of a given dtype.
--
-- Only 'D.Bool', 'D.Int64', 'D.Float' and 'D.Double' are supported. Any
-- other dtype is a custom type error.
type family ComputeHaskellType (dtype :: D.DType) :: Type where
  ComputeHaskellType D.Bool = Bool
  ComputeHaskellType D.Int64 = Int
  ComputeHaskellType D.Float = Float
  ComputeHaskellType D.Double = Double
  ComputeHaskellType dtype = TypeError (Text "Unsupported tensor type " :<>: ShowType dtype)
-- | Element type of the outer list when a tensor of the given shape is
-- written as nested Haskell lists of @ty@.
--
-- A rank-1 shape yields @ty@. Each further dimension adds one list layer.
-- Scalars (the empty shape) are rejected with a type error.
type family ComputeItemType (ty :: Type) (shape :: [Nat]) :: Type where
  ComputeItemType _ '[] = TypeError (Text "Scalars are not supported")
  ComputeItemType ty (_ ': '[]) = ty
  ComputeItemType ty (_ ': h ': t) = [ComputeItemType ty (h ': t)]
-- | Build typed tensors from nested lists, and read them back.
--
-- 'fromList' returns 'Nothing' in two cases:
--
-- * 'D._deepDims' cannot determine the list's dimensions;
-- * those dimensions differ from the type-level shape.
--
-- Otherwise the tensor is created with 'D.asTensor' and moved to @device@.
--
-- 'toList' maps 'Nothing' to the empty list. For a tensor, it first copies
-- to CPU device 0 and then reads the values with 'D.asValue'.
instance
  ( D.TensorLike [ComputeItemType (ComputeHaskellType dtype) shape],
    KnownDevice device,
    KnownShape shape
  ) =>
  IsList (Maybe (Tensor device dtype shape))
  where
  type Item (Maybe (Tensor device dtype shape)) = ComputeItemType (ComputeHaskellType dtype) shape

  fromList xs = do
    shapeXs <- D._deepDims xs
    if shapeVal @shape == shapeXs
      then pure . UnsafeMkTensor . D.toDevice (deviceVal @device) . D.asTensor $ xs
      else Nothing

  toList Nothing = []
  toList (Just t) = D.asValue . D.toDevice (D.Device D.CPU 0) . toDynamic $ t
-- | Element-wise arithmetic, delegated to the 'Num' instance of the untyped
-- 'D.Tensor'. The result keeps the operands' type-level device, dtype and shape.
--
-- NOTE(review): 'fromInteger' builds its tensor from an 'Int'
-- (@'D.asTensor' \@Int@) and only moves it to @device@. It neither converts
-- to @dtype@ nor expands to @shape@, so the runtime tensor need not match its
-- type (presumably an integer scalar). Confirm it is only used where
-- broadcasting makes that harmless.
instance KnownDevice device => Num (Tensor device dtype shape) where
  a + b = UnsafeMkTensor $ toDynamic a + toDynamic b
  a - b = UnsafeMkTensor $ toDynamic a - toDynamic b
  a * b = UnsafeMkTensor $ toDynamic a * toDynamic b
  negate t = UnsafeMkTensor $ negate $ toDynamic t
  abs t = UnsafeMkTensor $ abs $ toDynamic t
  signum t = UnsafeMkTensor $ signum $ toDynamic t
  fromInteger i =
    UnsafeMkTensor . D.toDevice (deviceVal @device) . D.asTensor @Int $ fromInteger @Int i
-- | Element-wise division and reciprocal, delegated to the 'Fractional'
-- instance of the untyped 'D.Tensor'.
--
-- NOTE(review): 'fromRational' builds its tensor from a 'Float'
-- (@'D.asTensor' \@Float@) and only moves it to @device@. It neither
-- converts to @dtype@ nor expands to @shape@, so the runtime tensor need not
-- match its type.
instance KnownDevice device => Fractional (Tensor device dtype shape) where
  a / b = UnsafeMkTensor $ toDynamic a / toDynamic b
  recip t = UnsafeMkTensor $ recip $ toDynamic t
  fromRational i =
    UnsafeMkTensor . D.toDevice (deviceVal @device) . D.asTensor @Float $ fromRational @Float i
-- | Shows the underlying untyped tensor. Type-level information is not printed.
instance Show (Tensor device dtype shape) where
  show (UnsafeMkTensor dynamic) = show dynamic
-- | Runtime shape, dtype and device for a tensor described at the type level.
--
-- The shape is reflected one dimension at a time. The dtype and device come
-- from the 'KnownDType' and 'KnownDevice' constraints of the empty-shape base
-- case. The methods do not mention the class parameters, so call sites must
-- pick them with type applications, e.g.
-- @optionsRuntimeShape \@shape \@dtype \@device@.
class TensorOptions (shape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  optionsRuntimeShape :: [Int]
  optionsRuntimeDType :: D.DType
  optionsRuntimeDevice :: D.Device

-- | Base case: no dimensions. The dtype and device are reflected here.
instance (KnownDType dtype, KnownDevice device) => TensorOptions '[] dtype device where
  optionsRuntimeShape = []
  optionsRuntimeDType = dtypeVal @dtype
  optionsRuntimeDevice = deviceVal @device

-- | Recursive case: prepend the head dimension; dtype and device come from the tail.
instance (KnownNat h, TensorOptions t dtype device) => TensorOptions (h ': t) dtype device where
  optionsRuntimeShape = natValI @h : optionsRuntimeShape @t @dtype @device
  optionsRuntimeDType = optionsRuntimeDType @t @dtype @device
  optionsRuntimeDevice = optionsRuntimeDevice @t @dtype @device
-- | @All pred xs@ holds when @pred x@ holds for every element @x@ of @xs@.
type family All (pred :: a -> Constraint) (l :: [a]) :: Constraint where
  All _ '[] = ()
  All pred (h ': t) = (pred h, All pred t)
-- | A runtime shape packaged with a 'KnownShape' witness for its type-level form.
data SomeShape where
  SomeShape :: forall (shape :: [Nat]). KnownShape shape => Proxy shape -> SomeShape

-- | Lifts a runtime list of dimension sizes to the type level.
--
-- Calls 'error' if any dimension is negative, because 'someNatVal' rejects
-- negative values.
someShape :: [Int] -> SomeShape
someShape [] = SomeShape $ Proxy @'[]
someShape (h : t) = case someNatVal (fromIntegral h) of
  Nothing -> error "Negative dimension in someShape!"
  Just (SomeNat (Proxy :: Proxy ht)) -> case someShape t of
    SomeShape (Proxy :: Proxy tt) -> SomeShape $ Proxy @(ht ': tt)
-- | A runtime dtype packaged with a 'KnownDType' witness for its promoted form.
data SomeDType where
  SomeDType :: forall (dtype :: D.DType). KnownDType dtype => Proxy dtype -> SomeDType

-- | Lifts a runtime 'D.DType' to the type level.
--
-- NOTE(review): this covers the nine dtypes that have 'KnownDType'
-- instances. If 'D.DType' has more constructors, the match is partial.
someDType :: D.DType -> SomeDType
someDType D.Bool = SomeDType $ Proxy @D.Bool
someDType D.UInt8 = SomeDType $ Proxy @D.UInt8
someDType D.Int8 = SomeDType $ Proxy @D.Int8
someDType D.Int16 = SomeDType $ Proxy @D.Int16
someDType D.Int32 = SomeDType $ Proxy @D.Int32
someDType D.Int64 = SomeDType $ Proxy @D.Int64
someDType D.Half = SomeDType $ Proxy @D.Half
someDType D.Float = SomeDType $ Proxy @D.Float
someDType D.Double = SomeDType $ Proxy @D.Double
-- | A runtime device packaged with a 'KnownDevice' witness for its promoted form.
data SomeDevice where
  SomeDevice :: forall (device :: (D.DeviceType, Nat)). KnownDevice device => Proxy device -> SomeDevice

-- | Lifts a runtime 'D.Device' (device type and index) to the type level.
--
-- Calls 'error' if the device index is negative, because 'someNatVal'
-- rejects negative values.
someDevice :: D.Device -> SomeDevice
someDevice (D.Device deviceType deviceIndex) =
  case someNatVal (fromIntegral deviceIndex) of
    Nothing -> error "Negative device index in someDevice!"
    Just (SomeNat (Proxy :: Proxy n)) -> case deviceType of
      D.CPU -> SomeDevice $ Proxy @'( 'D.CPU, n)
      D.CUDA -> SomeDevice $ Proxy @'( 'D.CUDA, n)
-- | Reflects the runtime shape, dtype and device of an untyped tensor to the
-- type level, then passes the resulting typed 'Tensor' to a continuation.
--
-- Fails with 'error' (via 'someShape' or 'someDevice') on a negative
-- dimension or device index.
withTensor ::
  D.Tensor ->
  ( forall shape dtype device.
    ( KnownDevice device,
      KnownDType dtype,
      KnownShape shape
    ) =>
    Tensor device dtype shape ->
    r
  ) ->
  r
withTensor untypedTensor f =
  case someShape (D.shape untypedTensor) of
    SomeShape (Proxy :: Proxy shape) ->
      case someDType (D.dtype untypedTensor) of
        SomeDType (Proxy :: Proxy dtype) ->
          case someDevice (D.device untypedTensor) of
            SomeDevice (Proxy :: Proxy device) ->
              f (UnsafeMkTensor @device @dtype @shape untypedTensor)
-- | Reflects only the runtime shape of an untyped tensor to the type level,
-- then passes the typed 'Tensor' to a continuation. The device and dtype are
-- chosen by the caller.
--
-- NOTE(review): the caller's @device@ and @dtype@ are not checked against
-- the tensor's runtime device and dtype. The 'KnownDevice' and 'KnownDType'
-- constraints are not used in this body. Callers must guarantee agreement.
-- Fails with 'error' (via 'someShape') on a negative dimension.
withTensorShape ::
  forall device dtype r.
  ( KnownDevice device,
    KnownDType dtype
  ) =>
  D.Tensor ->
  ( forall shape.
    KnownShape shape =>
    Tensor device dtype shape ->
    r
  ) ->
  r
withTensorShape untypedTensor f =
  case someShape (D.shape untypedTensor) of
    SomeShape (Proxy :: Proxy shape) ->
      f (UnsafeMkTensor @device @dtype @shape untypedTensor)
-- | Broadcasts two shapes given in reverse order (trailing dimension first).
--
-- Aligned dimensions must be equal, or one of them must be 1. Once one shape
-- runs out, the remainder of the other is kept. A direct mismatch yields
-- 'Nothing'. Mismatches deeper in the shape are propagated through
-- 'AppendToMaybe' (defined elsewhere; presumably it preserves 'Nothing').
type family ComputeBroadcast (reversedShape :: [Nat]) (reversedShape' :: [Nat]) :: Maybe [Nat] where
  ComputeBroadcast '[] reversedShape = Just reversedShape
  ComputeBroadcast reversedShape '[] = Just reversedShape
  ComputeBroadcast (h ': t) (h ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast (h ': t) (1 ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast (1 ': t) (h ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast _ _ = Nothing
-- | Turns a 'ComputeBroadcast' result back into a shape in normal order.
-- When the result is 'Nothing', it raises a custom type error naming both shapes.
type family CheckBroadcast (shape :: [Nat]) (shape' :: [Nat]) (result :: Maybe [Nat]) :: [Nat] where
  CheckBroadcast shape shape' Nothing =
    TypeError
      ( Text "The shapes "
          :<>: ShowType shape
          :<>: Text " and "
          :<>: ShowType shape'
          :<>: Text " cannot be broadcast"
      )
  CheckBroadcast _ _ (Just result) = (Reverse result)

-- | The shape obtained by broadcasting @shape@ against @shape'@, or a type
-- error if they are incompatible.
type Broadcast shape shape' =
  CheckBroadcast
    shape
    shape'
    ( ComputeBroadcast
        (Reverse shape)
        (Reverse shape')
    )
-- | Which dtypes basic arithmetic ('add', 'sub', 'mul', 'div') accepts on a device.
--
-- * CPU device 0: any dtype except 'D.Bool' and 'D.Half'.
-- * Any CUDA device: no restriction.
-- * Anything else, including CPU with a non-zero index: reported through
--   'UnsupportedDTypeForDevice' (defined elsewhere).
--
-- NOTE(review): confirm that rejecting CPU indices other than 0 is intended.
type family BasicArithmeticDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  BasicArithmeticDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsNotBool '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  BasicArithmeticDTypeIsValid '( 'D.CUDA, _) dtype = ()
  BasicArithmeticDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype
-- | Element-wise arithmetic with broadcasting and dtype promotion.
--
-- The operands may have different dtypes and broadcast-compatible shapes.
-- The result dtype is @DTypePromotion dtype dtype'@ and the result shape is
-- @Broadcast shape shape'@, on the same device. All three dtypes must satisfy
-- 'BasicArithmeticDTypeIsValid'. Each function delegates to the untyped
-- operation of the same name under the @D@ qualifier.
add,
  sub,
  mul,
  div ::
    forall shape'' shape shape' dtype dtype' dtype'' device.
    ( dtype'' ~ DTypePromotion dtype dtype',
      shape'' ~ Broadcast shape shape',
      BasicArithmeticDTypeIsValid device dtype,
      BasicArithmeticDTypeIsValid device dtype',
      BasicArithmeticDTypeIsValid device dtype''
    ) =>
    Tensor device dtype shape ->
    Tensor device dtype' shape' ->
    Tensor device dtype'' shape''
add a b = UnsafeMkTensor $ D.add (toDynamic a) (toDynamic b)
sub a b = UnsafeMkTensor $ D.sub (toDynamic a) (toDynamic b)
mul a b = UnsafeMkTensor $ D.mul (toDynamic a) (toDynamic b)
div a b = UnsafeMkTensor $ D.div (toDynamic a) (toDynamic b)
-- | Which dtypes the comparison operations accept on a device.
--
-- * CPU device 0: any dtype except 'D.Bool' and 'D.Half'.
-- * Any CUDA device: no restriction.
-- * Anything else, including CPU with a non-zero index: reported through
--   'UnsupportedDTypeForDevice' (defined elsewhere).
type family ComparisonDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  ComparisonDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsNotBool '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  ComparisonDTypeIsValid '( 'D.CUDA, _) dtype = ()
  ComparisonDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype
-- | Element-wise comparisons with broadcasting.
--
-- The operands may have different dtypes and broadcast-compatible shapes.
-- The result is a 'D.Bool' tensor of shape @Broadcast shape shape'@ on the
-- same device. Each named function delegates to the untyped operation of the
-- same name under the @D@ qualifier. The operators are aliases:
-- @(>.) = gt@, @(<.) = lt@, @(>=.) = ge@, @(<=.) = le@, @(==.) = eq@,
-- @(/=.) = ne@.
gt,
  lt,
  ge,
  le,
  eq,
  ne,
  (>.),
  (<.),
  (>=.),
  (<=.),
  (==.),
  (/=.) ::
    forall shape'' shape shape' dtype dtype' device.
    ( shape'' ~ Broadcast shape shape',
      ComparisonDTypeIsValid device dtype,
      ComparisonDTypeIsValid device dtype'
    ) =>
    Tensor device dtype shape ->
    Tensor device dtype' shape' ->
    Tensor device 'D.Bool shape''
gt a b = UnsafeMkTensor $ D.gt (toDynamic a) (toDynamic b)
lt a b = UnsafeMkTensor $ D.lt (toDynamic a) (toDynamic b)
ge a b = UnsafeMkTensor $ D.ge (toDynamic a) (toDynamic b)
le a b = UnsafeMkTensor $ D.le (toDynamic a) (toDynamic b)
eq a b = UnsafeMkTensor $ D.eq (toDynamic a) (toDynamic b)
ne a b = UnsafeMkTensor $ D.ne (toDynamic a) (toDynamic b)
(>.) = gt
(<.) = lt
(>=.) = ge
(<=.) = le
(==.) = eq
(/=.) = ne
-- | Compute the /reversed/ result shape of 'matmul' from the two reversed
-- input shapes.
--
-- Reversing puts the innermost dimensions first, so each equation reads:
--
-- * @[inner]@ with @[inner]@ (vector/vector): scalar result @'[]@.
-- * @[inner]@ with @cols : inner : batch'@ (vector/matrix): @cols@ is
--   prepended to the broadcast of @'[]@ and @batch'@.
-- * @inner : rows : batch@ with @[inner]@ (matrix/vector): @rows@ is
--   prepended to the broadcast of @'[]@ and @batch@.
-- * @inner : rows : batch@ with @cols : inner : batch'@ (matrix/matrix):
--   @cols@ then @rows@ are prepended to the broadcast of both batch parts.
--
-- Inputs matching none of these equations (e.g. differing inner sizes)
-- leave the family unreduced.
type family ComputeMatMul (reversedShape :: [Nat]) (reversedShape' :: [Nat]) :: Maybe [Nat] where
  ComputeMatMul '[inner] '[inner] =
    'Just '[]
  ComputeMatMul '[inner] (cols ': inner ': batch') =
    AppendToMaybe cols (ComputeBroadcast '[] batch')
  ComputeMatMul (inner ': rows ': batch) '[inner] =
    AppendToMaybe rows (ComputeBroadcast '[] batch)
  ComputeMatMul (inner ': rows ': batch) (cols ': inner ': batch') =
    AppendToMaybe cols (AppendToMaybe rows (ComputeBroadcast batch batch'))
-- | Turn the outcome of 'ComputeMatMul' into a result shape.
--
-- A 'Nothing' outcome becomes a custom compile-time error naming both
-- operand shapes; a @'Just' reversedResult@ is reversed back into normal
-- (outermost-first) order.
type family CheckMatMul (shape :: [Nat]) (shape' :: [Nat]) (result :: Maybe [Nat]) :: [Nat] where
  CheckMatMul lhs rhs 'Nothing =
    TypeError
      ( 'Text "The shapes "
          ':<>: 'ShowType lhs
          ':<>: 'Text " and "
          ':<>: 'ShowType rhs
          ':<>: 'Text " are not compatible with matrix multiplication"
      )
  CheckMatMul _ _ ('Just reversedResult) = Reverse reversedResult
-- | Result shape of multiplying tensors of shapes @lhs@ and @rhs@ with
-- 'matmul': the shapes are reversed, combined by 'ComputeMatMul', and
-- checked by 'CheckMatMul'.
type MatMul lhs rhs =
  CheckMatMul lhs rhs (ComputeMatMul (Reverse lhs) (Reverse rhs))
-- | Per-device dtype constraint required by 'matmul'.
--
-- * CPU device index 0: both 'DTypeIsNotBool' and 'DTypeIsNotHalf' must hold.
-- * CUDA (any index): 'DTypeIsFloatingPoint' must hold.
-- * Anything else, including a CPU device with a non-zero index, reduces
--   to 'UnsupportedDTypeForDevice'.
type family MatMulDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MatMulDTypeIsValid '( 'D.CPU, 0) dt =
    ( DTypeIsNotBool '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  MatMulDTypeIsValid '( 'D.CUDA, idx) dt =
    DTypeIsFloatingPoint '( 'D.CUDA, idx) dt
  MatMulDTypeIsValid '(devType, _) dt =
    UnsupportedDTypeForDevice devType dt
-- | Matrix product of two typed tensors.
--
-- The result shape is computed at the type level by 'MatMul'. Accepted
-- dtypes per device are restricted by 'MatMulDTypeIsValid'. The
-- computation itself is delegated to the untyped 'D.matmul'.
matmul ::
  forall shape'' shape shape' dtype device.
  ( shape'' ~ MatMul shape shape',
    MatMulDTypeIsValid device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape' ->
  Tensor device dtype shape''
matmul a b = UnsafeMkTensor $ D.matmul (toDynamic a) (toDynamic b)
-- | Select the slice at the statically known index @idx@ along dimension
-- @dim@, removing that dimension from the shape.
--
-- 'InRange' constrains the index at compile time. Both type-level naturals
-- are turned into 'Int's with 'natValI' and passed to 'D.select'.
select ::
  forall dim idx shape' shape dtype device.
  ( KnownNat dim,
    KnownNat idx,
    InRange shape dim idx,
    shape' ~ Remove shape dim
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
select t =
  UnsafeMkTensor $ D.select (natValI @dim) (natValI @idx) (toDynamic t)
-- | Like 'select', but the index is a runtime value.
--
-- The index is a @'Finite' n@, where @n@ is the size of dimension @dim@
-- (see 'Index'), so it cannot be out of range.
selectIdx ::
  forall dim n shape' shape dtype device.
  ( KnownNat dim,
    n ~ Index shape dim,
    shape' ~ Remove shape dim
  ) =>
  Tensor device dtype shape ->
  Finite n ->
  Tensor device dtype shape'
selectIdx t idx =
  UnsafeMkTensor $ D.select (natValI @dim) (getFiniteI idx) (toDynamic t)
-- | Total number of elements of a shape: the product of its dimensions.
-- The empty (scalar) shape has exactly one element.
type family Numel (shape :: [Nat]) :: Nat where
  Numel '[] = 1
  Numel (size ': sizes) = size * Numel sizes
-- | Reshape a tensor to @shape'@.
--
-- The constraint @Numel shape ~ Numel shape'@ guarantees at compile time
-- that the element count is unchanged. The target shape is reified with
-- 'shapeVal' and passed to 'D.reshape'.
reshape ::
  forall shape' shape dtype device.
  ( KnownShape shape',
    Numel shape ~ Numel shape'
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
reshape t = UnsafeMkTensor $ D.reshape (shapeVal @shape') (toDynamic t)
-- | Newtype wrapper used to select the overlapping
-- @'Castable' ('Wrap' t) 'D.ATenTensor'@ instance for any 'Unnamed' @t@.
newtype Wrap a = Wrap {unWrap :: a}
-- | Marshal any 'Unnamed' tensor-like value, wrapped in 'Wrap', to and from
-- the raw 'D.ATenTensor' pointer.
--
-- 'cast' goes through 'toDynamic'. 'uncast' rebuilds the value with
-- 'UnsafeMkTensor' and 'fromUnnamed'; no shape, dtype or device check is
-- performed on the incoming pointer.
instance {-# OVERLAPS #-} Unnamed t => Castable (Wrap t) D.ATenTensor where
  cast (Wrap t) f =
    let D.Unsafe aten_tensor = toDynamic t
     in f aten_tensor
  uncast aten_tensor f =
    f $ Wrap $ fromUnnamed $ UnsafeMkTensor (D.Unsafe aten_tensor)
-- | Marshal a 'NamedTensor' to and from its raw 'D.ATenTensor' pointer by
-- unwrapping or rewrapping 'FromTensor', 'UnsafeMkTensor' and 'D.Unsafe'.
-- 'uncast' trusts the caller that the pointer matches the type-level
-- device, dtype and shape.
instance Castable (NamedTensor device dtype shape) D.ATenTensor where
  cast (FromTensor (UnsafeMkTensor (D.Unsafe aten_tensor))) f = f aten_tensor
  uncast aten_tensor f = f . FromTensor . UnsafeMkTensor $ D.Unsafe aten_tensor
-- | Marshal a typed 'Tensor' to and from its raw 'D.ATenTensor' pointer.
-- 'uncast' trusts the caller that the pointer matches the type-level
-- device, dtype and shape.
instance Castable (Tensor device dtype shape) D.ATenTensor where
  cast (UnsafeMkTensor (D.Unsafe aten_tensor)) f = f aten_tensor
  uncast aten_tensor f = f $ UnsafeMkTensor (D.Unsafe aten_tensor)
-- | Marshal a homogeneous list of typed tensors to and from an ATen
-- @TensorList@.
--
-- Each element is first cast to its @'ForeignPtr' 'ATen.Tensor'@. The
-- resulting pointer list is then cast with the list instance, and 'uncast'
-- mirrors this. The type annotations pick the per-element 'Castable'
-- instance.
instance Castable [Tensor device dtype shape] (ForeignPtr ATen.TensorList) where
  cast xs f = do
    ptr_list <- mapM (\x -> cast x return :: IO (ForeignPtr ATen.Tensor)) xs
    cast ptr_list f
  uncast xs f = uncast xs $ \ptr_list -> do
    tensor_list <- mapM (\(x :: ForeignPtr ATen.Tensor) -> uncast x return) ptr_list
    f tensor_list
-- | Marshal a length-indexed vector of typed tensors to and from an ATen
-- @TensorList@.
--
-- 'cast' casts every element to its @'ForeignPtr' 'ATen.Tensor'@ and then
-- casts the resulting list. 'uncast' reads the pointer list back and
-- rebuilds a @'Vector' n@ with 'V.fromListN'.
--
-- If 'V.fromListN' cannot build the vector (too few elements), an 'IOError'
-- is raised that names both lengths. Previously this was an opaque
-- do-pattern-match failure.
instance KnownNat n => Castable (Vector n (Tensor device dtype shape)) (ForeignPtr ATen.TensorList) where
  cast xs f = do
    ptr_list <- V.toList <$> mapM (\x -> cast x return :: IO (ForeignPtr ATen.Tensor)) xs
    cast ptr_list f
  uncast xs f = uncast xs $ \ptr_list -> do
    tensor_list <- mapM (\(x :: ForeignPtr ATen.Tensor) -> uncast x return) ptr_list
    case V.fromListN tensor_list of
      Just vec -> f vec
      Nothing ->
        ioError . userError $
          "Castable (Vector n Tensor) TensorList: expected at least "
            ++ show (natValI @n)
            ++ " tensors, but the TensorList has "
            ++ show (length tensor_list)
-- | Step function for 'hfoldrM' that marshals each 'HList' element to a
-- 'D.ATenTensor' and conses it onto the accumulated list.
data TensorListFold = TensorListFold

-- | Runs the accumulated action first, then casts the current element, and
-- prepends the cast result.
instance (Castable x D.ATenTensor) => Apply' TensorListFold (x, IO [D.ATenTensor]) (IO [D.ATenTensor]) where
  apply' _ (x, mxs) = do
    xs <- mxs
    x' <- cast x return
    return (x' : xs)
-- | Step function for 'hunfoldrM' that rebuilds 'HList' elements from a
-- list of 'D.ATenTensor's. The target 'HList' type decides which instance
-- runs at each step.
data TensorListUnfold = TensorListUnfold

-- | Terminal step: the list must be exhausted. Leftover tensors used to hit
-- a non-exhaustive pattern; now they raise a descriptive 'IOError'.
instance Apply TensorListUnfold [D.ATenTensor] (IO HNothing) where
  apply _ [] = pure HNothing
  apply _ (_ : _) =
    ioError . userError $
      "TensorListUnfold: more tensors were supplied than the HList has elements"

-- | Productive step: uncast the head tensor and hand back the rest. Running
-- out of tensors used to hit a non-exhaustive pattern; now it raises a
-- descriptive 'IOError'.
instance (Castable x D.ATenTensor) => Apply TensorListUnfold [D.ATenTensor] (IO (HJust (x, [D.ATenTensor]))) where
  apply _ (x : xs) = do
    x' <- uncast x return
    return $ HJust (x', xs)
  apply _ [] =
    ioError . userError $
      "TensorListUnfold: fewer tensors were supplied than the HList has elements"
-- | Marshal a heterogeneous list of tensors to and from a plain list of
-- 'D.ATenTensor' pointers.
--
-- 'cast' folds the 'HList' with 'TensorListFold'. 'uncast' unfolds the
-- pointer list with 'TensorListUnfold', driven by the type @l@.
instance
  ( HFoldrM IO TensorListFold [D.ATenTensor] l [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] res,
    HUnfoldM IO TensorListUnfold res l,
    res ~ (HUnfoldMRes IO [D.ATenTensor] l)
  ) =>
  Castable (HList l) [D.ATenTensor]
  where
  cast xs f = f =<< toPtrs xs
    where
      toPtrs :: HList l -> IO [D.ATenTensor]
      toPtrs = hfoldrM TensorListFold ([] :: [D.ATenTensor])
  uncast xs f = f =<< fromPtrs xs
    where
      fromPtrs :: [D.ATenTensor] -> IO (HList l)
      fromPtrs = hunfoldrM TensorListUnfold
-- | Marshal a heterogeneous list of tensors to and from an ATen
-- @TensorList@. This goes through the @[D.ATenTensor]@ instance and then
-- the pointer-list instance, in both directions.
instance Castable (HList l) [D.ATenTensor] => Castable (HList l) (ForeignPtr ATen.TensorList) where
  cast xs f = do
    ts <- cast xs return :: IO [ForeignPtr ATen.Tensor]
    cast ts f
  uncast xs f = uncast xs $ \(ptrList :: [ForeignPtr ATen.Tensor]) -> do
    ts <- uncast ptrList return :: IO (HList l)
    f ts
-- | Convert a tensor to a sparse layout via 'D.toSparse'. The type-level
-- device, dtype and shape are unchanged.
toSparse :: Tensor device dtype shape -> Tensor device dtype shape
toSparse t = UnsafeMkTensor $ D.toSparse (toDynamic t)
-- | Convert a tensor to a dense layout via 'D.toDense'. The type-level
-- device, dtype and shape are unchanged.
toDense :: Tensor device dtype shape -> Tensor device dtype shape
toDense t = UnsafeMkTensor $ D.toDense (toDynamic t)
-- | Move a tensor from any device to the CPU via 'D.toCPU'.
toCPU ::
  forall device shape dtype.
  Tensor device dtype shape ->
  CPUTensor dtype shape
toCPU input = UnsafeMkTensor $ D.toCPU (toDynamic input)
-- | Move a tensor to CUDA device 0 via 'D.toCUDA'.
--
-- NOTE(review): the @device'@ type variable does not occur in the type; it
-- appears to be kept only so that existing explicit type applications keep
-- their positions. Confirm before removing it.
toCUDA ::
  forall device' device shape dtype.
  Tensor device dtype shape ->
  CUDATensor 0 dtype shape
toCUDA t = UnsafeMkTensor $ D.toCUDA (toDynamic t)
-- | Move a typed tensor ('Tensor' or 'NamedTensor') to the device given as
-- the first type argument.
--
-- The runtime device comes from 'deviceVal'. The result type has its device
-- replaced by 'ReplaceDevice''''; dtype and shape are kept.
toDevice ::
  forall device' device dtype shape t t'.
  ( KnownDevice device',
    IsUnnamed t device dtype shape,
    Unnamed t',
    t' ~ ReplaceDevice'' t device'
  ) =>
  t ->
  t'
toDevice = fromUnnamed . UnsafeMkTensor . D.toDevice (deviceVal @device') . toDynamic
-- | Convert a typed tensor ('Tensor' or 'NamedTensor') to the dtype given
-- as the first type argument.
--
-- The runtime dtype comes from 'dtypeVal' and is applied with 'D.toType'.
-- The result type has its dtype replaced by 'ReplaceDType''''.
toDType ::
  forall dtype' dtype device shape t t'.
  ( KnownDType dtype',
    IsUnnamed t device dtype shape,
    Unnamed t',
    t' ~ ReplaceDType'' t dtype'
  ) =>
  t ->
  t'
toDType = fromUnnamed . UnsafeMkTensor . D.toType (dtypeVal @dtype') . toDynamic
-- | Number of dimensions of a typed tensor.
--
-- Computed from the type-level shape via 'optionsRuntimeShape'; the tensor
-- argument itself is not inspected.
dim ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  Int
dim _ = length $ optionsRuntimeShape @shape @dtype @device
-- | Shape of a typed tensor as a list of 'Int's.
--
-- Read from the type via 'optionsRuntimeShape'; the tensor argument itself
-- is not inspected.
shape ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  [Int]
shape _ = optionsRuntimeShape @shape @dtype @device
-- | Runtime dtype of a typed tensor.
--
-- Read from the type via 'optionsRuntimeDType'; the tensor argument itself
-- is not inspected.
dtype ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  D.DType
dtype _ = optionsRuntimeDType @shape @dtype @device
-- | Runtime device of a typed tensor.
--
-- Read from the type via 'optionsRuntimeDevice'; the tensor argument
-- itself is not inspected.
device ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  D.Device
device _ = optionsRuntimeDevice @shape @dtype @device
-- | Extract an 'Int' from a typed tensor via the untyped 'D.toInt'.
--
-- NOTE(review): neither dtype nor shape is constrained here; presumably
-- 'D.toInt' expects a single-element tensor. Confirm before relying on it.
toInt ::
  Tensor device dtype shape ->
  Int
toInt t = D.toInt $ toDynamic t
-- | Extract the value of a scalar 'D.Float' tensor. The tensor is copied
-- to the CPU with 'toCPU' before 'D.asValue' reads it.
toFloat :: forall device. Tensor device 'D.Float '[] -> Float
toFloat t = D.asValue . toDynamic . toCPU $ t
-- | Extract the value of a scalar 'D.Double' tensor. The tensor is copied
-- to the CPU with 'toCPU' before 'D.asValue' reads it.
toDouble :: forall device. Tensor device 'D.Double '[] -> Double
toDouble t = D.asValue . toDynamic . toCPU $ t
-- | Extract the value of a scalar 'D.Bool' tensor. The tensor is copied to
-- the CPU with 'toCPU' before 'D.asValue' reads it.
toBool :: forall device. Tensor device 'D.Bool '[] -> Bool
toBool t = D.asValue . toDynamic . toCPU $ t
-- | Map a Haskell element type to its 'D.DType'.
--
-- Any applied type constructor is stripped off, so the element type is
-- found inside nested containers.
type family ToDType a :: D.DType where
  ToDType Bool = 'D.Bool
  ToDType Int = 'D.Int64
  ToDType Float = 'D.Float
  ToDType Double = 'D.Double
  ToDType (wrapper inner) = ToDType inner
-- | Collect the chain of type constructors wrapping a scalar element type
-- into a 'Shape', outermost first.
--
-- 'Bool', 'Int', 'Float' and 'Double' end the chain with the empty shape.
type family ToShape a :: Shape where
  ToShape Bool = '[]
  ToShape Int = '[]
  ToShape Float = '[]
  ToShape Double = '[]
  ToShape (wrapper inner) = wrapper ': ToShape inner
-- | Zero-based position of the first occurrence of @a@ in a named 'Shape'.
--
-- If @a@ does not occur, compilation fails with a custom type error. The
-- previous message ("Not find a type:" immediately followed by the type)
-- was ungrammatical and lacked a separating space.
type family FindDim (a :: Size) (shape :: Shape) :: Nat where
  FindDim a (a ': _) = 0
  FindDim a (_ ': rest) = 1 + FindDim a rest
  FindDim a _ =
    TypeError
      ( 'Text "Could not find the type "
          ':<>: 'ShowType a
          ':<>: 'Text " in the shape."
      )
-- | A typed tensor whose shape is a list of named sizes ('Shape') instead
-- of plain naturals.
--
-- The only constructor wraps an ordinary 'Tensor' whose 'Nat' shape equals
-- @'ToNats' shape'@.
data NamedTensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: Shape) where
  FromTensor ::
    forall device dtype shape' shape.
    (shape ~ ToNats shape') =>
    Tensor device dtype shape ->
    NamedTensor device dtype shape'
-- | A 'NamedTensor' is unnamed by converting its shape with 'ToNats'.
-- Device and dtype carry over unchanged, and the conversions simply
-- unwrap or rewrap 'FromTensor'.
instance Unnamed (NamedTensor device dtype shape) where
  type UTShape (NamedTensor device dtype shape) = ToNats shape
  type UTDevice (NamedTensor device dtype shape) = device
  type UTDType (NamedTensor device dtype shape) = dtype
  toUnnamed (FromTensor t) = t
  fromUnnamed = FromTensor
  toDynamic (FromTensor (UnsafeMkTensor t)) = t
-- | Arithmetic on 'NamedTensor's.
--
-- Each operation unwraps the operands with 'toUnnamed', applies the
-- corresponding 'Tensor' 'Num' operation, and rewraps the result with
-- 'fromUnnamed'.
instance (KnownDevice device) => Num (NamedTensor device dtype shape) where
  a + b = fromUnnamed $ toUnnamed a + toUnnamed b
  a - b = fromUnnamed $ toUnnamed a - toUnnamed b
  a * b = fromUnnamed $ toUnnamed a * toUnnamed b
  negate = fromUnnamed . negate . toUnnamed
  abs = fromUnnamed . abs . toUnnamed
  signum = fromUnnamed . signum . toUnnamed
  fromInteger = fromUnnamed . fromInteger
-- | Fractional arithmetic on 'NamedTensor's, delegated to the 'Tensor'
-- instance through 'toUnnamed' and 'fromUnnamed'.
instance KnownDevice device => Fractional (NamedTensor device dtype shape) where
  a / b = fromUnnamed $ toUnnamed a / toUnnamed b
  recip = fromUnnamed . recip . toUnnamed
  fromRational = fromUnnamed . fromRational
-- | Shows a 'NamedTensor' exactly like its underlying unnamed 'Tensor'.
instance Show (NamedTensor device dtype shape) where
  show = show . toUnnamed
-- | Replace the device index of a 'Tensor' or 'NamedTensor' type, keeping
-- its dtype and shape.
type family ReplaceDevice'' (tensor :: t) (device :: (D.DeviceType, Nat)) :: t where
  ReplaceDevice'' (Tensor _ dt sh) newDevice = Tensor newDevice dt sh
  ReplaceDevice'' (NamedTensor _ dt sh) newDevice = NamedTensor newDevice dt sh
-- | Replace the dtype index of a 'Tensor' or 'NamedTensor' type, keeping
-- its device and shape.
type family ReplaceDType'' (tensor :: t) (dtype :: D.DType) :: t where
  ReplaceDType'' (Tensor dev _ sh) newDType = Tensor dev newDType sh
  ReplaceDType'' (NamedTensor dev _ sh) newDType = NamedTensor dev newDType sh