{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Tensor where

import Control.Arrow
import Control.Category
import Data.Finite
import Data.Kind
  ( Constraint,
    Type,
  )
import Data.Maybe
import Data.Proxy
import Data.Reflection
import Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as V
import Foreign.ForeignPtr
import Foreign.Storable
import GHC.Exts
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D hiding (select)
import Torch.HList
import Torch.Internal.Cast
import Torch.Internal.Class
  ( Castable (..),
    CppTuple2 (..),
    CppTuple3 (..),
    CppTuple4 (..),
  )
import qualified Torch.Internal.Type as ATen
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.Typed.Auxiliary
import Prelude hiding (id, (.))

-- | Reflects a type-level shape (a list of dimension sizes) to the term level.
class KnownShape (shape :: [Nat]) where
  -- | The dimensions of @shape@, in the same order as the type-level list.
  shapeVal :: [Int]

instance KnownShape '[] where
  shapeVal = []

instance (KnownNat h, KnownShape t) => KnownShape (h ': t) where
  -- Head dimension first, followed by the reflected tail.
  shapeVal = natValI @h : shapeVal @t

-- | Converts a bounded index ('Finite' @n@) to an 'Int'.
--
-- 'getFinite' yields an 'Integer'; 'fromIntegral' narrows it to 'Int', which
-- could only wrap around if @n@ exceeded @maxBound :: Int@.
getFiniteI :: Finite n -> Int
getFiniteI = fromIntegral . getFinite

-- | Reflects a promoted 'D.DType' to its term-level value.
class KnownDType (dtype :: D.DType) where
  dtypeVal :: D.DType

instance KnownDType 'D.Bool where
  dtypeVal = D.Bool

instance KnownDType 'D.UInt8 where
  dtypeVal = D.UInt8

instance KnownDType 'D.Int8 where
  dtypeVal = D.Int8

instance KnownDType 'D.Int16 where
  dtypeVal = D.Int16

instance KnownDType 'D.Int32 where
  dtypeVal = D.Int32

instance KnownDType 'D.Int64 where
  dtypeVal = D.Int64

instance KnownDType 'D.Half where
  dtypeVal = D.Half

instance KnownDType 'D.Float where
  dtypeVal = D.Float

instance KnownDType 'D.Double where
  dtypeVal = D.Double

-- | Maps either a Haskell element type ('Bool', 'Int', 'Float', 'Double') or
-- an already-promoted 'D.DType' to the corresponding 'D.DType'. Any other
-- argument is rejected with a custom type error.
--
-- This is a closed family: equations are tried top to bottom, so the
-- catch-all error equation must remain last.
type family ComputeDType (dtype' :: dtype) :: D.DType where
  ComputeDType Bool = D.Bool
  ComputeDType D.Bool = D.Bool
  ComputeDType D.UInt8 = D.UInt8
  ComputeDType D.Int8 = D.Int8
  ComputeDType D.Int16 = D.Int16
  ComputeDType D.Int32 = D.Int32
  ComputeDType Int = D.Int64
  ComputeDType D.Int64 = D.Int64
  -- Half has a 'KnownDType' instance and a 'someDType' case in this module,
  -- so accept it here as well instead of falling into the error equation.
  ComputeDType 'D.Half = 'D.Half
  ComputeDType Float = D.Float
  ComputeDType D.Float = D.Float
  ComputeDType Double = D.Double
  ComputeDType D.Double = D.Double
  ComputeDType dtype' = TypeError (Text "Unsupported tensor type " :<>: ShowType dtype')

-- | Reflects a promoted @(device type, device index)@ pair to a 'D.Device'.
class KnownDevice (device :: (D.DeviceType, Nat)) where
  deviceVal :: D.Device

instance (KnownNat n) => KnownDevice '( 'D.CPU, n) where
  deviceVal = D.Device D.CPU (natValInt16 @n)

instance (KnownNat n) => KnownDevice '( 'D.CUDA, n) where
  deviceVal = D.Device D.CUDA (natValInt16 @n)

-- | Kind of a single "dimension" descriptor (a @Type -> Type@ such as
-- @Vector n@, or a generic representation), consumed by 'ToNat'.
type Size = Type -> Type

-- | A list of 'Size' descriptors; converted to a @[Nat]@ shape by 'ToNats'.
type Shape = [Type -> Type]

-- | Computes a dimension size from a 'Size' descriptor.
--
-- * Generic metadata wrappers ('S1', 'D1', 'C1') are transparent.
-- * Products (':*:') add the sizes of both sides.
-- * Sums (':+:') take the larger of the two sizes.
-- * A @Vector n@ field, or @Vector n@ itself, contributes @n@.
-- * Any other field ('K1') and 'U1' contribute 1.
-- * Anything else is unfolded through its generic 'Rep' (applied to @()@).
--
-- Closed family: equations are matched top to bottom, so the
-- @K1 R (Vector n _)@ case must precede the general @K1 _ _@ case, and the
-- 'Rep' fallback must stay last.
type family ToNat (shape :: Size) :: Nat where
  ToNat (S1 ('MetaSel _ _ _ _) f) = ToNat f
  ToNat (D1 _ f) = ToNat f
  ToNat (C1 _ f) = ToNat f
  ToNat (l :*: r) = ToNat l + ToNat r
  ToNat (l :+: r) = If (ToNat l <=? ToNat r) (ToNat r) (ToNat l)
  ToNat (K1 R (Vector n _)) = n
  ToNat (K1 _ _) = 1
  ToNat U1 = 1
  ToNat (Vector n) = n
  ToNat a = ToNat (Rep (a ()))

-- | Converts every 'Size' descriptor in a 'Shape' with 'ToNat'.
type family ToNats (shape :: Shape) :: [Nat] where
  ToNats '[] = '[]
  ToNats (x ': xs) = ToNat x ': ToNats xs

-- | The canonical 'Size' descriptor for a dimension of size @n@.
type family FromNat (shape :: Nat) :: Size where
  FromNat n = Vector n

-- | Converts a @[Nat]@ shape to a 'Shape' of @Vector@ descriptors.
type family FromNats (shape :: [Nat]) :: Shape where
  FromNats '[] = '[]
  FromNats (x ': xs) = FromNat x ': FromNats xs

-- | Types that can be viewed as a typed 'Tensor' and as a dynamic 'D.Tensor'.
--
-- The associated types give the shape, device and dtype of that view;
-- 'IsUnnamed' ties them to the indices used by 'toUnnamed' / 'fromUnnamed'.
class Unnamed t where
  type UTShape t :: [Nat]
  type UTDevice t :: (D.DeviceType, Nat)
  type UTDType t :: D.DType
  -- | View @t@ as a typed tensor whose indices match the associated types.
  toUnnamed ::
    forall device dtype shape.
    IsUnnamed t device dtype shape =>
    t ->
    Tensor device dtype shape
  -- | Inverse of 'toUnnamed'.
  fromUnnamed ::
    forall device dtype shape.
    IsUnnamed t device dtype shape =>
    Tensor device dtype shape ->
    t
  -- | Extract the underlying dynamic tensor.
  toDynamic ::
    t -> D.Tensor

-- | Bundles an 'Unnamed' constraint with equalities fixing @device@, @dtype@
-- and @shape@ to the associated types of @t@.
type family IsUnnamed t (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: [Nat]) :: Constraint where
  IsUnnamed t device dtype shape =
    ( Unnamed t,
      device ~ (UTDevice t),
      dtype ~ (UTDType t),
      shape ~ (UTShape t)
    )

-- | A typed 'Tensor' is trivially its own unnamed view: the associated types
-- read back its indices and the conversions are the identity.
instance Unnamed (Tensor device dtype shape) where
  type UTShape (Tensor device dtype shape) = shape
  type UTDevice (Tensor device dtype shape) = device
  type UTDType (Tensor device dtype shape) = dtype
  toUnnamed = id
  fromUnnamed = id
  -- Strip the phantom indices and expose the wrapped dynamic tensor.
  toDynamic (UnsafeMkTensor t) = t

-- | A dynamic 'D.Tensor' tagged with phantom device, dtype and shape indices.
--
-- 'UnsafeMkTensor' performs no check that the wrapped tensor actually has
-- the claimed device, dtype or shape; callers are responsible for that.
data Tensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: [Nat]) where
  UnsafeMkTensor :: forall device dtype shape. D.Tensor -> Tensor device dtype shape

-- | A typed tensor on CPU device index 0.
type CPUTensor = Tensor '( 'D.CPU, 0)

-- | A typed tensor on the CUDA device with the given index.
type CUDATensor deviceIndex = Tensor '( 'D.CUDA, deviceIndex)

-- | A typed tensor whose shape index has been hidden existentially.
data UnknownShapeTensor device dtype = forall shape. UnknownShapeTensor (Tensor device dtype shape)

-- | The Haskell element type used to read/write values of a given dtype.
-- Only Bool, Int64, Float and Double are supported; anything else is a
-- custom type error (the catch-all equation must stay last).
type family ComputeHaskellType (dtype :: D.DType) :: Type where
  ComputeHaskellType D.Bool = Bool
  ComputeHaskellType D.Int64 = Int
  ComputeHaskellType D.Float = Float
  ComputeHaskellType D.Double = Double
  ComputeHaskellType dtype = TypeError (Text "Unsupported tensor type " :<>: ShowType dtype)

-- | The element type of the outermost list in the nested-list representation
-- of a tensor with element type @ty@ and the given shape. A rank-1 shape gives
-- @ty@; each further dimension adds one list layer (e.g. shape @'[2, 3]@ with
-- @Float@ gives @[Float]@, so the whole value is @[[Float]]@). Rank 0 is
-- rejected with a type error.
type family ComputeItemType (ty :: Type) (shape :: [Nat]) :: Type where
  ComputeItemType _ '[] = TypeError (Text "Scalars are not supported")
  ComputeItemType ty (_ ': '[]) = ty
  ComputeItemType ty (_ ': h ': t) = [ComputeItemType ty (h ': t)]

-- | Build typed tensors from (nested) Haskell lists and read them back.
--
-- 'fromList' returns 'Nothing' when 'D._deepDims' cannot determine the list's
-- dimensions or when they differ from the type-level @shape@; otherwise it
-- converts the list with 'D.asTensor' and moves the result to @device@.
-- 'toList' maps 'Nothing' to the empty list, and otherwise copies the tensor
-- to CPU device 0 before extracting its values with 'D.asValue'.
instance
  ( D.TensorLike [ComputeItemType (ComputeHaskellType dtype) shape],
    KnownDevice device,
    KnownShape shape
  ) =>
  IsList (Maybe (Tensor device dtype shape))
  where
  type Item (Maybe (Tensor device dtype shape)) = ComputeItemType (ComputeHaskellType dtype) shape
  fromList xs = do
    shapeXs <- D._deepDims xs
    if shapeVal @shape == shapeXs
      then pure $ UnsafeMkTensor . D.toDevice (deviceVal @device) . D.asTensor $ xs
      else Nothing
  toList Nothing = []
  toList (Just t) = D.asValue . D.toDevice (D.Device D.CPU 0) . toDynamic $ t

-- | Element-wise arithmetic delegated to the 'Num' instance of the dynamic
-- 'D.Tensor'; the phantom indices of the operands are carried over unchanged.
--
-- NOTE(review): 'fromInteger' builds a tensor from a single 'Int' value on
-- @device@, so its runtime dtype and shape presumably do not match the
-- phantom @dtype@ and @shape@. It appears to rely on broadcasting and type
-- promotion in the dynamic ops it is combined with; confirm before using such
-- a literal on its own.
instance KnownDevice device => Num (Tensor device dtype shape) where
  a + b = UnsafeMkTensor $ toDynamic a + toDynamic b
  a - b = UnsafeMkTensor $ toDynamic a - toDynamic b
  a * b = UnsafeMkTensor $ toDynamic a * toDynamic b
  negate t = UnsafeMkTensor $ negate $ toDynamic t
  abs t = UnsafeMkTensor $ abs $ toDynamic t
  signum t = UnsafeMkTensor $ signum $ toDynamic t
  fromInteger i =
    UnsafeMkTensor . D.toDevice (deviceVal @device) . D.asTensor @Int $ fromInteger @Int i

-- | Element-wise division delegated to the 'Fractional' instance of the
-- dynamic 'D.Tensor'.
--
-- NOTE(review): 'fromRational' builds a tensor from a single 'Float' value on
-- @device@, so its runtime dtype and shape presumably do not match the
-- phantom @dtype@ and @shape@; confirm before relying on it in isolation.
instance KnownDevice device => Fractional (Tensor device dtype shape) where
  a / b = UnsafeMkTensor $ toDynamic a / toDynamic b
  recip t = UnsafeMkTensor $ recip $ toDynamic t
  fromRational i =
    UnsafeMkTensor . D.toDevice (deviceVal @device) . D.asTensor @Float $ fromRational @Float i

-- | Shows the wrapped dynamic tensor; the type-level indices are not printed.
instance Show (Tensor device dtype shape) where
  show (UnsafeMkTensor dynamic) = show dynamic

-- | Reflects shape, dtype and device together. The methods take no
-- arguments, so callers pick the instance with type applications in the order
-- shape, dtype, device.
class TensorOptions (shape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  optionsRuntimeShape :: [Int]
  optionsRuntimeDType :: D.DType
  optionsRuntimeDevice :: D.Device

-- Base case: empty shape; dtype and device come from their own reflections.
instance (KnownDType dtype, KnownDevice device) => TensorOptions '[] dtype device where
  optionsRuntimeShape = []
  optionsRuntimeDType = dtypeVal @dtype
  optionsRuntimeDevice = deviceVal @device

-- Inductive case: prepend the head dimension; dtype and device are taken
-- from the instance for the tail.
instance (KnownNat h, TensorOptions t dtype device) => TensorOptions (h ': t) dtype device where
  optionsRuntimeShape = natValI @h : optionsRuntimeShape @t @dtype @device
  optionsRuntimeDType = optionsRuntimeDType @t @dtype @device
  optionsRuntimeDevice = optionsRuntimeDevice @t @dtype @device

--------------------------------------------------------------------------------
-- Untyped -> Typed typecasts
--------------------------------------------------------------------------------

-- | Applies a constraint to every element of a type-level list, producing
-- their conjunction (the empty constraint for an empty list).
type family All (pred :: a -> Constraint) (l :: [a]) :: Constraint where
  All _ '[] = ()
  All pred (h ': t) = (pred h, All pred t)

-- | A type-level shape hidden existentially, together with its 'KnownShape'
-- evidence.
data SomeShape where
  SomeShape :: forall (shape :: [Nat]). KnownShape shape => Proxy shape -> SomeShape

-- | Lifts a runtime list of dimensions to a type-level shape, one dimension at
-- a time via 'someNatVal'.
--
-- Calls 'error' if any dimension is negative.
someShape :: [Int] -> SomeShape
someShape [] = SomeShape $ Proxy @'[]
someShape (h : t) = case someNatVal (fromIntegral h) of
  Nothing -> error "Negative dimension in someShape!"
  Just (SomeNat (Proxy :: Proxy ht)) -> case someShape t of
    SomeShape (Proxy :: Proxy tt) -> SomeShape $ Proxy @(ht ': tt)

-- | A promoted dtype hidden existentially, together with its 'KnownDType'
-- evidence.
data SomeDType where
  SomeDType :: forall (dtype :: D.DType). KnownDType dtype => Proxy dtype -> SomeDType

-- | Lifts a runtime 'D.DType' to the corresponding promoted dtype.
someDType :: D.DType -> SomeDType
someDType D.Bool = SomeDType $ Proxy @D.Bool
someDType D.UInt8 = SomeDType $ Proxy @D.UInt8
someDType D.Int8 = SomeDType $ Proxy @D.Int8
someDType D.Int16 = SomeDType $ Proxy @D.Int16
someDType D.Int32 = SomeDType $ Proxy @D.Int32
someDType D.Int64 = SomeDType $ Proxy @D.Int64
someDType D.Half = SomeDType $ Proxy @D.Half
someDType D.Float = SomeDType $ Proxy @D.Float
someDType D.Double = SomeDType $ Proxy @D.Double

-- | A promoted @(device type, device index)@ pair hidden existentially,
-- together with its 'KnownDevice' evidence.
data SomeDevice where
  SomeDevice :: forall (device :: (D.DeviceType, Nat)). KnownDevice device => Proxy device -> SomeDevice

-- | Lifts a runtime 'D.Device' to the corresponding promoted device pair.
--
-- Calls 'error' if the device index is negative.
someDevice :: D.Device -> SomeDevice
someDevice D.Device {..} = case someNatVal (fromIntegral deviceIndex) of
  Nothing -> error "Negative device index in someDevice!"
  Just (SomeNat (Proxy :: Proxy n)) -> case deviceType of
    D.CPU -> SomeDevice $ Proxy @'( 'D.CPU, n)
    D.CUDA -> SomeDevice $ Proxy @'( 'D.CUDA, n)

-- | Reflects the runtime shape, dtype and device of an untyped tensor to the
-- type level and passes the corresponding typed view to the continuation.
--
-- Calls 'error' (via 'someShape' / 'someDevice') if a dimension or the
-- device index is negative.
withTensor ::
  D.Tensor ->
  ( forall shape dtype device.
    ( KnownDevice device,
      KnownDType dtype,
      KnownShape shape
    ) =>
    Tensor device dtype shape ->
    r
  ) ->
  r
withTensor untypedTensor f = case someShape (D.shape untypedTensor) of
  SomeShape (Proxy :: Proxy shape) -> case someDType (D.dtype untypedTensor) of
    SomeDType (Proxy :: Proxy dtype) -> case someDevice (D.device untypedTensor) of
      SomeDevice (Proxy :: Proxy device) ->
        f $ UnsafeMkTensor @device @dtype @shape untypedTensor

-- | Like 'withTensor', but only the shape is reflected from the runtime
-- tensor; @device@ and @dtype@ are chosen by the caller.
--
-- Calls 'error' (via 'someShape') if a dimension is negative.
withTensorShape ::
  forall device dtype r.
  ( KnownDevice device,
    KnownDType dtype
  ) =>
  D.Tensor ->
  ( forall shape.
    KnownShape shape =>
    Tensor device dtype shape ->
    r
  ) ->
  r
withTensorShape untypedTensor f = case someShape (D.shape untypedTensor) of
  -- ToDo: check device/dtype of untyped tensor.
  -- Until then, the caller-chosen @device@ and @dtype@ are trusted unchecked.
  SomeShape (Proxy :: Proxy shape) ->
    f $ UnsafeMkTensor @device @dtype @shape untypedTensor

--------------------------------------------------------------------------------
-- Broadcast type-level function
--------------------------------------------------------------------------------

-- | Broadcasts two shapes given in reverse order (trailing dimension first).
--
-- When one list runs out, the rest of the other is taken as-is. Equal
-- dimensions, or a 1 on either side, combine to the non-1 dimension (via
-- 'AppendToMaybe', presumably consing onto a 'Just' result and propagating
-- 'Nothing' — see its definition). Any other pair makes the whole result
-- 'Nothing'. Closed family: the catch-all equation must stay last.
type family ComputeBroadcast (reversedShape :: [Nat]) (reversedShape' :: [Nat]) :: Maybe [Nat] where
  ComputeBroadcast '[] reversedShape = Just reversedShape
  ComputeBroadcast reversedShape '[] = Just reversedShape
  ComputeBroadcast (h ': t) (h ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast (h ': t) (1 ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast (1 ': t) (h ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast _ _ = Nothing

-- | Turns a 'ComputeBroadcast' result into a shape: a custom type error
-- naming both original shapes on 'Nothing', otherwise the result reversed
-- back to normal (leading dimension first) order.
type family CheckBroadcast (shape :: [Nat]) (shape' :: [Nat]) (result :: Maybe [Nat]) :: [Nat] where
  CheckBroadcast shape shape' Nothing =
    TypeError
      ( Text "The shapes "
          :<>: ShowType shape
          :<>: Text " and "
          :<>: ShowType shape'
          :<>: Text " cannot be broadcast"
      )
  CheckBroadcast _ _ (Just result) = (Reverse result)

-- | The broadcast shape of @shape@ and @shape'@, or a type error if they are
-- incompatible.
type Broadcast shape shape' =
  CheckBroadcast
    shape
    shape'
    ( ComputeBroadcast
        (Reverse shape)
        (Reverse shape')
    )

-- | Dtype restrictions for 'add' / 'sub' / 'mul' / 'div'.
--
-- On CPU device 0 the dtype must satisfy 'DTypeIsNotBool' and
-- 'DTypeIsNotHalf'; any CUDA device imposes no restriction. Every other
-- device (including a CPU with a nonzero index) falls through to
-- 'UnsupportedDTypeForDevice'.
type family BasicArithmeticDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  BasicArithmeticDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsNotBool '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  BasicArithmeticDTypeIsValid '( 'D.CUDA, _) dtype = ()
  BasicArithmeticDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype

-- | Element-wise binary arithmetic with broadcasting and dtype promotion.
--
-- The result shape is 'Broadcast' of the operand shapes (a type error if they
-- are incompatible) and the result dtype is 'DTypePromotion' of the operand
-- dtypes. Both operand dtypes and the result dtype must satisfy
-- 'BasicArithmeticDTypeIsValid' on the shared @device@. Each function unwraps
-- both operands and applies the same-named dynamic @D.@ operation.
add,
  sub,
  mul,
  div ::
    forall shape'' shape shape' dtype dtype' dtype'' device.
    ( dtype'' ~ DTypePromotion dtype dtype',
      shape'' ~ Broadcast shape shape',
      BasicArithmeticDTypeIsValid device dtype,
      BasicArithmeticDTypeIsValid device dtype',
      BasicArithmeticDTypeIsValid device dtype''
    ) =>
    Tensor device dtype shape ->
    Tensor device dtype' shape' ->
    Tensor device dtype'' shape''
add a b = UnsafeMkTensor $ D.add (toDynamic a) (toDynamic b)
sub a b = UnsafeMkTensor $ D.sub (toDynamic a) (toDynamic b)
mul a b = UnsafeMkTensor $ D.mul (toDynamic a) (toDynamic b)
div a b = UnsafeMkTensor $ D.div (toDynamic a) (toDynamic b)

-- | Constraint deciding whether a tensor of dtype @dtype@ on @device@ may be
-- used in a comparison. On the CPU (index 0), Bool and Half are rejected. On
-- CUDA, every dtype is accepted. Any other device, including a CPU index other
-- than 0, is reported via 'UnsupportedDTypeForDevice'.
type family ComparisonDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  ComparisonDTypeIsValid '( 'D.CPU, 0) dt =
    ( DTypeIsNotBool '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  ComparisonDTypeIsValid '( 'D.CUDA, _) _ = ()
  ComparisonDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt

-- | Comparison operators. Each compares two tensors living on the same device
-- and yields a 'D.Bool' tensor whose shape is the broadcast of the two input
-- shapes. The symbolic operators are aliases for the named functions.
gt,
  lt,
  ge,
  le,
  eq,
  ne,
  (>.),
  (<.),
  (>=.),
  (<=.),
  (==.),
  (/=.) ::
    forall shape'' shape shape' dtype dtype' device.
    ( shape'' ~ Broadcast shape shape',
      ComparisonDTypeIsValid device dtype,
      ComparisonDTypeIsValid device dtype'
    ) =>
    Tensor device dtype shape ->
    Tensor device dtype' shape' ->
    Tensor device 'D.Bool shape''
gt a b = UnsafeMkTensor (D.gt (toDynamic a) (toDynamic b))
lt a b = UnsafeMkTensor (D.lt (toDynamic a) (toDynamic b))
ge a b = UnsafeMkTensor (D.ge (toDynamic a) (toDynamic b))
le a b = UnsafeMkTensor (D.le (toDynamic a) (toDynamic b))
eq a b = UnsafeMkTensor (D.eq (toDynamic a) (toDynamic b))
ne a b = UnsafeMkTensor (D.ne (toDynamic a) (toDynamic b))
a >. b = gt a b
a <. b = lt a b
a >=. b = ge a b
a <=. b = le a b
a ==. b = eq a b
a /=. b = ne a b

-- | Result shape of 'matmul', computed on /reversed/ shapes so that the
-- contracted and output dimensions sit at the head of each list. Returns
-- 'Nothing' when the operands cannot be multiplied, so that 'CheckMatMul' can
-- report a readable error.
type family ComputeMatMul (reversedShape :: [Nat]) (reversedShape' :: [Nat]) :: Maybe [Nat] where
  -- vector x vector: the shared dimension is contracted, giving a scalar
  ComputeMatMul (k ': '[]) (k ': '[]) = Just '[]
  -- vector x (batched) matrix
  ComputeMatMul (k ': '[]) (m ': k ': reversedBroadcastShape') = AppendToMaybe m (ComputeBroadcast '[] reversedBroadcastShape')
  -- (batched) matrix x vector
  ComputeMatMul (k ': n ': reversedBroadcastShape) (k ': '[]) = AppendToMaybe n (ComputeBroadcast '[] reversedBroadcastShape)
  -- (batched) matrix x (batched) matrix: leading dimensions go through ComputeBroadcast
  ComputeMatMul (k ': n ': reversedBroadcastShape) (m ': k ': reversedBroadcastShape') = AppendToMaybe m (AppendToMaybe n (ComputeBroadcast reversedBroadcastShape reversedBroadcastShape'))
  -- Anything else (mismatched inner dimension, 0-d operand) is not
  -- multipliable. Without this equation the family would stay stuck, and the
  -- TypeError in 'CheckMatMul' would never be reached.
  ComputeMatMul _ _ = Nothing

-- | Turn the outcome of 'ComputeMatMul' into a shape. A 'Nothing' raises a
-- custom type error naming both operand shapes. A 'Just' holds the reversed
-- result shape, which is flipped back into normal order.
type family CheckMatMul (shape :: [Nat]) (shape' :: [Nat]) (result :: Maybe [Nat]) :: [Nat] where
  CheckMatMul lhs rhs 'Nothing =
    TypeError
      ( Text "The shapes "
          :<>: ShowType lhs
          :<>: Text " and "
          :<>: ShowType rhs
          :<>: Text " are not compatible with matrix multiplication"
      )
  CheckMatMul _ _ ('Just reversedResult) = Reverse reversedResult

-- | Output shape of multiplying tensors of shapes @lhs@ and @rhs@. Both
-- shapes are reversed before 'ComputeMatMul' sees them, and the outcome is
-- validated by 'CheckMatMul'.
type MatMul lhs rhs =
  CheckMatMul lhs rhs (ComputeMatMul (Reverse lhs) (Reverse rhs))

-- | Constraint deciding whether @dtype@ may be used with 'matmul' on @device@.
-- On the CPU (index 0), Bool and Half are rejected. On CUDA, the dtype must be
-- floating point. Any other device, including a CPU index other than 0, is
-- reported via 'UnsupportedDTypeForDevice'.
type family MatMulDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MatMulDTypeIsValid '( 'D.CPU, 0) dt =
    ( DTypeIsNotBool '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  MatMulDTypeIsValid '( 'D.CUDA, ix) dt = DTypeIsFloatingPoint '( 'D.CUDA, ix) dt
  MatMulDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt

-- | Matrix multiplication. Both operands share a device and dtype, and the
-- result shape is computed at the type level by 'MatMul'.
-- See https://pytorch.org/docs/stable/torch.html#torch.matmul.
matmul ::
  forall shape'' shape shape' dtype device.
  ( shape'' ~ MatMul shape shape',
    MatMulDTypeIsValid device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape' ->
  Tensor device dtype shape''
matmul a b = UnsafeMkTensor (D.matmul lhs rhs)
  where
    lhs = toDynamic a
    rhs = toDynamic b

-- | Slice a tensor along dimension @dim@ at the type-level index @idx@.
-- That dimension is removed from the result shape ('Remove shape dim'), and
-- @idx@ is constrained against @shape@ by 'InRange'.
select ::
  forall dim idx shape' shape dtype device.
  ( KnownNat dim,
    KnownNat idx,
    InRange shape dim idx,
    shape' ~ Remove shape dim
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
select t = UnsafeMkTensor (D.select dimIndex position (toDynamic t))
  where
    dimIndex = natValI @dim
    position = natValI @idx

-- | Like 'select', but the index is a runtime value of type @Finite n@,
-- where @n@ is the size of dimension @dim@ ('Index shape dim'). The selected
-- dimension is removed from the result shape.
selectIdx ::
  forall dim n shape' shape dtype device.
  ( KnownNat dim,
    n ~ Index shape dim,
    shape' ~ Remove shape dim
  ) =>
  Tensor device dtype shape ->
  Finite n ->
  Tensor device dtype shape'
selectIdx t idx = UnsafeMkTensor (D.select dimIndex position (toDynamic t))
  where
    dimIndex = natValI @dim
    position = getFiniteI idx

-- | Number of elements of a tensor with the given shape: the product of its
-- dimensions. The empty (scalar) shape has exactly one element.
type family Numel (shape :: [Nat]) :: Nat where
  Numel '[] = 1
  Numel (d ': ds) = d * Numel ds

-- | reshape
-- Reinterpret a tensor under the statically known shape @shape'@. The
-- element count must be unchanged ('Numel' of both shapes agree).
-- >>> t :: CPUTensor 'D.Int64 '[2,3,4] = fromJust [[[111,112,113,114],[121,122,123,124],[131,132,133,134]],[[211,212,213,214],[221,222,223,224],[231,232,233,234]]]
-- >>> t' = reshape @'[24] t
-- >>> toList . Just $ t'
-- [111,112,113,114,121,122,123,124,131,132,133,134,211,212,213,214,221,222,223,224,231,232,233,234]
-- >>> toList . Just $ reshape @'[2,3,4] t'
-- [[[111,112,113,114],[121,122,123,124],[131,132,133,134]],[[211,212,213,214],[221,222,223,224],[231,232,233,234]]]
reshape ::
  forall shape' shape dtype device.
  ( KnownShape shape',
    Numel shape ~ Numel shape'
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
reshape t = UnsafeMkTensor (D.reshape targetShape (toDynamic t))
  where
    targetShape = shapeVal @shape'

-- | Newtype around an unnamed tensor type. It lets 'Castable' be provided for
-- every 'Unnamed' type without an overlapping instance of the form
-- @(Unnamed t => Castable t D.ATenTensor)@.
newtype Wrap a = Wrap {unWrap :: a}

-- | Cast a wrapped unnamed tensor to and from its raw ATen pointer.
instance {-# OVERLAPS #-} Unnamed t => Castable (Wrap t) D.ATenTensor where
  -- Lazily unpack the dynamic tensor to reach the underlying pointer.
  cast (Wrap t) f = f ptr
    where
      D.Unsafe ptr = toDynamic t
  -- Rebuild a typed tensor around the pointer and convert it back to @t@.
  uncast ptr f = f (Wrap (fromUnnamed (UnsafeMkTensor (D.Unsafe ptr))))

-- | A named tensor casts to and from the ATen pointer held by its underlying
-- unnamed tensor.
instance Castable (NamedTensor device dtype shape) D.ATenTensor where
  cast (FromTensor (UnsafeMkTensor (D.Unsafe ptr))) f = f ptr
  uncast ptr f = f (FromTensor (UnsafeMkTensor (D.Unsafe ptr)))

-- | A typed tensor casts to and from the ATen pointer it wraps; the type-level
-- device, dtype and shape are attached without any runtime check.
instance Castable (Tensor device dtype shape) D.ATenTensor where
  cast (UnsafeMkTensor (D.Unsafe ptr)) f = f ptr
  uncast ptr f = f (UnsafeMkTensor (D.Unsafe ptr))

-- | A list of typed tensors casts to and from an ATen @TensorList@ by way of
-- a list of per-tensor ATen pointers.
instance Castable [Tensor device dtype shape] (ForeignPtr ATen.TensorList) where
  cast tensors f =
    traverse (\t -> cast t pure :: IO (ForeignPtr ATen.Tensor)) tensors
      >>= \ptrs -> cast ptrs f
  uncast listPtr f =
    uncast listPtr $ \ptrs ->
      traverse (\(p :: ForeignPtr ATen.Tensor) -> uncast p pure) ptrs >>= f

-- | A length-indexed vector of typed tensors casts to and from an ATen
-- @TensorList@ by way of a list of per-tensor ATen pointers.
instance KnownNat n => Castable (Vector n (Tensor device dtype shape)) (ForeignPtr ATen.TensorList) where
  cast xs f = do
    ptr_list <- V.toList <$> mapM (\x -> (cast x return :: IO (ForeignPtr ATen.Tensor))) xs
    cast ptr_list f
  -- Unpack the TensorList and rebuild a @Vector n@. If 'V.fromListN' cannot
  -- produce a vector of length @n@, fail in IO with a message that states
  -- both lengths. This replaces the previous anonymous pattern-match failure;
  -- the exception is still an IOError.
  uncast xs f = uncast xs $ \ptr_list -> do
    tensor_list <- mapM (\(x :: ForeignPtr ATen.Tensor) -> uncast x return) ptr_list
    case V.fromListN tensor_list of
      Just vec -> f vec
      Nothing ->
        fail $
          "uncast: cannot build a Vector of length "
            ++ show (natValI @n)
            ++ " from a TensorList of "
            ++ show (length tensor_list)
            ++ " tensor(s)"

-- | Folding step for 'hfoldrM' that flattens an 'HList' into a list of ATen
-- tensor pointers. Each element is cast to 'D.ATenTensor' and consed onto the
-- list produced by the rest of the fold.
data TensorListFold = TensorListFold

instance (Castable x D.ATenTensor) => Apply' TensorListFold (x, IO [D.ATenTensor]) (IO [D.ATenTensor]) where
  -- The accumulated action runs first; the current element is cast afterwards.
  apply' _ (x, accumulated) = do
    rest <- accumulated
    ptr <- cast x pure
    pure (ptr : rest)

-- | Unfolding step for 'hunfoldrM' that rebuilds an 'HList' from a list of
-- ATen tensor pointers, consuming one pointer per 'HList' element.
data TensorListUnfold = TensorListUnfold

-- | End of the 'HList'. Every pointer must have been consumed. Previously,
-- leftover pointers caused a bare non-exhaustive-pattern crash; now they fail
-- with a message giving the count.
instance Apply TensorListUnfold [D.ATenTensor] (IO HNothing) where
  apply _ [] = pure HNothing
  apply _ leftover =
    fail $
      "TensorListUnfold: "
        ++ show (length leftover)
        ++ " tensor(s) left over after the HList was filled"

-- | Next 'HList' element: uncast the head pointer and pass the rest along.
-- Running out of pointers early now fails with a message instead of a
-- non-exhaustive-pattern crash.
instance (Castable x D.ATenTensor) => Apply TensorListUnfold [D.ATenTensor] (IO (HJust (x, [D.ATenTensor]))) where
  apply _ (ptr : rest) = do
    x' <- uncast ptr return
    return $ HJust (x', rest)
  apply _ [] = fail "TensorListUnfold: ran out of tensors before the HList was filled"

-- | An 'HList' of castable values converts to and from a list of ATen tensor
-- pointers. The list is built with 'TensorListFold' and taken apart with
-- 'TensorListUnfold'.
instance
  ( HFoldrM IO TensorListFold [D.ATenTensor] l [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] res,
    HUnfoldM IO TensorListUnfold res l,
    res ~ (HUnfoldMRes IO [D.ATenTensor] l)
  ) =>
  Castable (HList l) [D.ATenTensor]
  where
  cast hlist f = toPtrs hlist >>= f
    where
      toPtrs :: HList l -> IO [D.ATenTensor]
      toPtrs = hfoldrM TensorListFold ([] :: [D.ATenTensor])
  uncast ptrs f = fromPtrs ptrs >>= f
    where
      fromPtrs :: [D.ATenTensor] -> IO (HList l)
      fromPtrs = hunfoldrM TensorListUnfold

-- | An 'HList' converts to and from an ATen @TensorList@ by going through the
-- intermediate list of 'D.ATenTensor' pointers.
instance Castable (HList l) [D.ATenTensor] => Castable (HList l) (ForeignPtr ATen.TensorList) where
  cast hlist f =
    (cast hlist pure :: IO [ForeignPtr ATen.Tensor]) >>= (`cast` f)
  uncast listPtr f =
    uncast listPtr $ \(ptrs :: [ForeignPtr ATen.Tensor]) ->
      (uncast ptrs pure :: IO (HList l)) >>= f

--------------------------------------------------------------------------------
-- Move tensors
--------------------------------------------------------------------------------

-- TODO: track sparsity in tensor type
-- | Convert to a sparse layout via 'D.toSparse'. Device, dtype and shape are
-- unchanged at the type level.
toSparse :: Tensor device dtype shape -> Tensor device dtype shape
toSparse = UnsafeMkTensor . D.toSparse . toDynamic

-- TODO: track sparsity in tensor type
-- | Convert to a dense layout via 'D.toDense'. Device, dtype and shape are
-- unchanged at the type level.
toDense :: Tensor device dtype shape -> Tensor device dtype shape
toDense = UnsafeMkTensor . D.toDense . toDynamic

-- -- TODO: is this a device?
-- toMKLDNN
--   :: forall device' device shape dtype
--    . Tensor device  dtype shape
--   -> Tensor device' dtype shape
-- toMKLDNN t = UnsafeMkTensor $ D.toMKLDNN (toDynamic t)

-- | Move a tensor from any device to the CPU via 'D.toCPU'.
-- TODO: can this fail?
toCPU ::
  forall device shape dtype.
  Tensor device dtype shape ->
  CPUTensor dtype shape
toCPU = UnsafeMkTensor . D.toCPU . toDynamic

-- | Move a tensor to the first CUDA device (index 0) via 'D.toCUDA'.
-- TODO: what if this fails?
-- NOTE(review): @device'@ does not appear in the type. Removing it would shift
-- explicit type applications, so it is left in place.
toCUDA ::
  forall device' device shape dtype.
  Tensor device dtype shape ->
  CUDATensor 0 dtype shape
toCUDA = UnsafeMkTensor . D.toCUDA . toDynamic

-- | Move an unnamed tensor to the device @device'@ selected by type
-- application. The result type is @t@ with its device replaced
-- ('ReplaceDevice'''').
-- TODO: what if this fails?
toDevice ::
  forall device' device dtype shape t t'.
  ( KnownDevice device',
    IsUnnamed t device dtype shape,
    Unnamed t',
    t' ~ ReplaceDevice'' t device'
  ) =>
  t ->
  t'
toDevice t = fromUnnamed (UnsafeMkTensor (D.toDevice target (toDynamic t)))
  where
    target = deviceVal @device'

-- | Convert an unnamed tensor to the dtype @dtype'@ selected by type
-- application. The result type is @t@ with its dtype replaced
-- ('ReplaceDType'''').
toDType ::
  forall dtype' dtype device shape t t'.
  ( KnownDType dtype',
    IsUnnamed t device dtype shape,
    Unnamed t',
    t' ~ ReplaceDType'' t dtype'
  ) =>
  t ->
  t'
toDType t = fromUnnamed (UnsafeMkTensor (D.toType target (toDynamic t)))
  where
    target = dtypeVal @dtype'

--------------------------------------------------------------------------------
-- Auxiliary functions for accessing tensor options as values
--------------------------------------------------------------------------------

-- | Number of dimensions of the tensor.
--   Computed from the type alone; the argument is never inspected.
dim ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  Int
dim = const (length (optionsRuntimeShape @shape @dtype @device))

-- | Shape of the tensor as a list of dimension sizes.
--   Computed from the type alone; the argument is never inspected.
shape ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  [Int]
shape = const (optionsRuntimeShape @shape @dtype @device)

-- | Returns the tensor data type as a runtime 'D.DType' value.
--
--   Uses compile-time information only: the value is reflected from the
--   type-level @dtype@ via 'optionsRuntimeDType'. The argument is never
--   inspected; it only serves to fix the type @t@.
dtype ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  D.DType
dtype _ = optionsRuntimeDType @shape @dtype @device

-- | Returns the tensor device as a runtime 'D.Device' value.
--
--   Uses compile-time information only: the value is reflected from the
--   type-level @device@ via 'optionsRuntimeDevice'. The argument is never
--   inspected; it only serves to fix the type @t@.
device ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  D.Device
device _ = optionsRuntimeDevice @shape @dtype @device

--------------------------------------------------------------------------------
-- Auxiliary functions for accessing tensors as values
--------------------------------------------------------------------------------

-- | Extracts a value from a typed tensor as an 'Int' by forgetting the
--   type-level information ('toDynamic') and calling 'D.toInt'.
--
-- TODO: figure out what device, dtype, and shape we need for this;
--   the signature currently accepts any tensor.
-- NOTE(review): unlike 'toFloat' / 'toDouble' / 'toBool', this does not go
--   through 'toCPU' first; confirm 'D.toInt' handles non-CPU tensors.
toInt ::
  Tensor device dtype shape ->
  Int
toInt = D.toInt . toDynamic

-- | Extracts the value of a zero-dimensional 'D.Float' tensor.
--   The tensor is first moved to the CPU with 'toCPU', then its untyped
--   form ('toDynamic') is read with 'D.asValue'.
toFloat :: forall device. Tensor device 'D.Float '[] -> Float
toFloat = D.asValue . toDynamic . toCPU

-- | Extracts the value of a zero-dimensional 'D.Double' tensor.
--   The tensor is first moved to the CPU with 'toCPU', then its untyped
--   form ('toDynamic') is read with 'D.asValue'.
toDouble :: forall device. Tensor device 'D.Double '[] -> Double
toDouble = D.asValue . toDynamic . toCPU

-- | Extracts the value of a zero-dimensional 'D.Bool' tensor.
--   The tensor is first moved to the CPU with 'toCPU', then its untyped
--   form ('toDynamic') is read with 'D.asValue'.
toBool :: forall device. Tensor device 'D.Bool '[] -> Bool
toBool = D.asValue . toDynamic . toCPU

--------------------------------------------------------------------------------
-- NamedTensor
--------------------------------------------------------------------------------

-- | Maps a Haskell element type to the corresponding 'D.DType'.
--   'Bool', 'Int', 'Float' and 'Double' map to 'D.Bool', 'D.Int64',
--   'D.Float' and 'D.Double'. Any type application @f a@ is stripped by the
--   last equation, so only the innermost scalar type decides the result.
type family ToDType a :: D.DType where
  ToDType Bool = 'D.Bool
  ToDType Int = 'D.Int64
  ToDType Float = 'D.Float
  ToDType Double = 'D.Double
  ToDType (f a) = ToDType a

-- | Collects the type constructors wrapping a scalar element type into a
--   type-level 'Shape', outermost constructor first. The scalar types
--   'Bool', 'Int', 'Float' and 'Double' end the list with @'[]@.
type family ToShape a :: Shape where
  ToShape Bool = '[]
  ToShape Int = '[]
  ToShape Float = '[]
  ToShape Double = '[]
  ToShape (f a) = f ': ToShape a

-- | Zero-based position of the first occurrence of dimension @a@ in a
--   named 'Shape'. Reduces to a custom compile-time error when @a@ does not
--   occur in the shape.
type family FindDim (a :: Size) (shape :: Shape) :: Nat where
  FindDim a (a ': _) = 0
  FindDim a (_ ': ax) = 1 + FindDim a ax
  -- Reached only when the list is exhausted without a match.
  FindDim a _ = TypeError (Text "Could not find dimension " :<>: ShowType a :<>: Text " in the shape.")

-- | A typed tensor indexed by a named 'Shape' (a list of 'Size' type
--   constructors) instead of a list of 'Nat' sizes.
--
--   'FromTensor' wraps a 'Tensor' whose 'Nat' shape is required, by the
--   equality constraint, to be 'ToNats' of the named shape.
data NamedTensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: Shape) where
  FromTensor :: forall device dtype shape' shape. shape ~ ToNats shape' => Tensor device dtype shape -> NamedTensor device dtype shape'

-- | Conversions between a 'NamedTensor' and its underlying 'Tensor'.
--   The unnamed shape is 'ToNats' of the named shape; device and dtype are
--   unchanged. Conversions only add or remove the 'FromTensor' constructor.
instance Unnamed (NamedTensor device dtype shape) where
  type UTShape (NamedTensor device dtype shape) = ToNats shape
  type UTDevice (NamedTensor device dtype shape) = device
  type UTDType (NamedTensor device dtype shape) = dtype

  toUnnamed (FromTensor t) = t

  fromUnnamed = FromTensor

  -- Peel off both layers: the named wrapper and the typed 'Tensor'
  -- constructor, yielding the untyped tensor.
  toDynamic (FromTensor (UnsafeMkTensor t)) = t

-- | Arithmetic on named tensors.
--
--   Every method unwraps its arguments with 'toUnnamed', applies the 'Num'
--   operation of the underlying 'Tensor', and rewraps with 'fromUnnamed'.
--   The named shape is carried through unchanged.
--   NOTE(review): 'KnownDevice' is presumably required by the underlying
--   'Tensor' instance (e.g. for 'fromInteger'); confirm.
instance (KnownDevice device) => Num (NamedTensor device dtype shape) where
  a + b = fromUnnamed (toUnnamed a + toUnnamed b)
  a - b = fromUnnamed (toUnnamed a - toUnnamed b)
  a * b = fromUnnamed (toUnnamed a * toUnnamed b)
  negate = fromUnnamed . negate . toUnnamed
  abs = fromUnnamed . abs . toUnnamed
  signum = fromUnnamed . signum . toUnnamed
  -- Build the literal as an unnamed 'Tensor' of the matching shape, then wrap it.
  fromInteger = fromUnnamed . fromInteger

-- | Division on named tensors.
--
--   Every method unwraps its arguments with 'toUnnamed', applies the
--   'Fractional' operation of the underlying 'Tensor', and rewraps with
--   'fromUnnamed'. The named shape is carried through unchanged.
instance KnownDevice device => Fractional (NamedTensor device dtype shape) where
  a / b = fromUnnamed (toUnnamed a / toUnnamed b)
  recip = fromUnnamed . recip . toUnnamed
  -- Build the literal as an unnamed 'Tensor' of the matching shape, then wrap it.
  fromRational = fromUnnamed . fromRational

-- | Shows a named tensor exactly like its underlying unnamed 'Tensor';
--   the dimension names are not included in the output.
instance Show (NamedTensor device dtype shape) where
  show = show . toUnnamed

-- | Replaces the device index of a 'Tensor' or 'NamedTensor' type,
--   leaving its dtype and shape untouched.
type family ReplaceDevice'' (tensor :: t) (device :: (D.DeviceType, Nat)) :: t where
  ReplaceDevice'' (Tensor device0 dtype shape) device1 = Tensor device1 dtype shape
  ReplaceDevice'' (NamedTensor device0 dtype shape) device1 = NamedTensor device1 dtype shape

-- | Replaces the dtype index of a 'Tensor' or 'NamedTensor' type,
--   leaving its device and shape untouched.
type family ReplaceDType'' (tensor :: t) (dtype :: D.DType) :: t where
  ReplaceDType'' (Tensor device dtype0 shape) dtype1 = Tensor device dtype1 shape
  ReplaceDType'' (NamedTensor device dtype0 shape) dtype1 = NamedTensor device dtype1 shape