{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.NamedTensor where

import Data.Default.Class
import Data.Kind
import Data.Maybe (fromJust)
import Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as V
import GHC.Exts
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.Lens
import qualified Torch.Tensor as D
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Tensor

-- | Types that can be converted to and from a 'NamedTensor' placed on the
-- @'( 'D.CPU, 0)@ device, and to and from a plain 'ToNestedList'
-- representation.
--
-- The dtype and shape of the tensor are computed by @ToDType a@ and
-- @ToShape a@; those type families are not defined in this part of the
-- module.
class NamedTensorLike a where
  -- | The plain representation of @a@ chosen by each instance.
  type ToNestedList a :: Type
  -- | Convert a value into its 'ToNestedList' representation.
  toNestedList :: a -> ToNestedList a
  -- | Build a 'NamedTensor' on @'( 'D.CPU, 0)@ from a value.
  asNamedTensor :: a -> NamedTensor '( 'D.CPU, 0) (ToDType a) (ToShape a)
  -- | Rebuild a value from its 'ToNestedList' representation.
  fromNestedList :: ToNestedList a -> a
  -- | Read a value back out of a 'NamedTensor' on @'( 'D.CPU, 0)@.
  fromNamedTensor :: NamedTensor '( 'D.CPU, 0) (ToDType a) (ToShape a) -> a

-- | Scalar instance for 'Bool'.
--
-- A 'Bool' is its own 'ToNestedList' representation, so the list
-- conversions are the identity.
instance NamedTensorLike Bool where
  type ToNestedList Bool = Bool

  toNestedList = id

  -- The untyped tensor produced by 'D.asTensor' is wrapped with
  -- 'UnsafeMkTensor', which attaches the type-level device, dtype and shape
  -- without any constraint linking them to the runtime tensor, and is then
  -- converted to the named form with 'fromUnnamed'.
  asNamedTensor = fromUnnamed . UnsafeMkTensor . D.asTensor

  fromNestedList = id

  -- Drop to the untyped tensor and read the value back with 'D.asValue'.
  fromNamedTensor = D.asValue . toDynamic

-- | Scalar instance for 'Int'.
--
-- An 'Int' is its own 'ToNestedList' representation, so the list
-- conversions are the identity.
instance NamedTensorLike Int where
  type ToNestedList Int = Int

  toNestedList = id

  -- The untyped tensor produced by 'D.asTensor' is wrapped with
  -- 'UnsafeMkTensor', which attaches the type-level device, dtype and shape
  -- without any constraint linking them to the runtime tensor, and is then
  -- converted to the named form with 'fromUnnamed'.
  asNamedTensor = fromUnnamed . UnsafeMkTensor . D.asTensor

  fromNestedList = id

  -- Drop to the untyped tensor and read the value back with 'D.asValue'.
  fromNamedTensor = D.asValue . toDynamic

-- | Scalar instance for 'Float'.
--
-- A 'Float' is its own 'ToNestedList' representation, so the list
-- conversions are the identity.
instance NamedTensorLike Float where
  type ToNestedList Float = Float

  toNestedList = id

  -- The untyped tensor produced by 'D.asTensor' is wrapped with
  -- 'UnsafeMkTensor', which attaches the type-level device, dtype and shape
  -- without any constraint linking them to the runtime tensor, and is then
  -- converted to the named form with 'fromUnnamed'.
  asNamedTensor = fromUnnamed . UnsafeMkTensor . D.asTensor

  fromNestedList = id

  -- Drop to the untyped tensor and read the value back with 'D.asValue'.
  fromNamedTensor = D.asValue . toDynamic

-- | Scalar instance for 'Double'.
--
-- A 'Double' is its own 'ToNestedList' representation, so the list
-- conversions are the identity.
instance NamedTensorLike Double where
  type ToNestedList Double = Double

  toNestedList = id

  -- The untyped tensor produced by 'D.asTensor' is wrapped with
  -- 'UnsafeMkTensor', which attaches the type-level device, dtype and shape
  -- without any constraint linking them to the runtime tensor, and is then
  -- converted to the named form with 'fromUnnamed'.
  asNamedTensor = fromUnnamed . UnsafeMkTensor . D.asTensor

  fromNestedList = id

  -- Drop to the untyped tensor and read the value back with 'D.asValue'.
  fromNamedTensor = D.asValue . toDynamic

-- | Sized vectors of tensor-like elements.
--
-- The 'ToNestedList' representation is an ordinary list holding the
-- representation of each element, in vector order.
instance (KnownNat n, D.TensorLike (ToNestedList a), NamedTensorLike a) => NamedTensorLike (Vector n a) where
  type ToNestedList (Vector n a) = [ToNestedList a]

  -- Convert the vector to a list, then convert every element.
  toNestedList v = fmap toNestedList (V.toList v)

  -- Go through the nested-list form; 'UnsafeMkTensor' attaches the
  -- type-level device, dtype and shape without any constraint linking them
  -- to the runtime tensor.
  asNamedTensor v = fromUnnamed . UnsafeMkTensor . D.asTensor $ toNestedList v

  -- 'V.fromList' returns a 'Maybe', and 'fromJust' is partial: this throws
  -- when 'V.fromList' yields 'Nothing'.
  -- NOTE(review): presumably that happens when the list length differs from
  -- @n@ -- confirm against the vector-sized documentation.
  fromNestedList = fmap fromNestedList . fromJust . V.fromList

  -- Read the nested list out of the untyped tensor, then rebuild the vector.
  fromNamedTensor = fromNestedList . D.asValue . toDynamic

-- | Any type @vec n a@ that is 'Coercible' to a sized @'Vector' n a@.
--
-- Values are coerced to the underlying 'Vector' and then handled
-- element-wise; the 'ToNestedList' representation is a list of element
-- representations.
instance {-# OVERLAPS #-} (Coercible (vec n a) (Vector n a), KnownNat n, D.TensorLike (ToNestedList a), NamedTensorLike a) => NamedTensorLike (vec n a) where
  type ToNestedList (vec n a) = [ToNestedList a]

  -- The annotation on 'coerce' fixes the target to the sized vector; the
  -- type application selects the element instance of 'toNestedList'.
  toNestedList v = map (toNestedList @a) (V.toList (coerce v :: Vector n a))

  -- Go through the nested-list form; 'UnsafeMkTensor' attaches the
  -- type-level device, dtype and shape without any constraint linking them
  -- to the runtime tensor.
  asNamedTensor v = fromUnnamed . UnsafeMkTensor . D.asTensor $ toNestedList v

  -- Build the underlying 'Vector' first, then coerce it to @vec n a@.
  -- 'fromJust' is partial: this throws when 'V.fromList' yields 'Nothing'.
  -- NOTE(review): presumably that happens when the list length differs from
  -- @n@ -- confirm against the vector-sized documentation.
  fromNestedList v = coerce (fmap fromNestedList . fromJust . V.fromList $ v :: Vector n a)

  -- Read the nested list out of the untyped tensor, then rebuild the value.
  fromNamedTensor = fromNestedList . D.asValue . toDynamic

-- | Containers @g a@ whose @a@-typed values can be reached through the
-- 'types' traversal and which have a 'Default' value.
--
-- The 'ToNestedList' representation is a list with the representation of
-- each @a@ found by the traversal.
instance {-# OVERLAPS #-} (Generic (g a), Default (g a), HasTypes (g a) a, KnownNat (ToNat g), D.TensorLike (ToNestedList a), NamedTensorLike a) => NamedTensorLike (g a) where
  type ToNestedList (g a) = [ToNestedList a]

  -- Collect every value targeted by @types \@a@ and convert each one.
  toNestedList v = map (toNestedList @a) (flattenValues (types @a) v)

  -- Go through the nested-list form; 'UnsafeMkTensor' attaches the
  -- type-level device, dtype and shape without any constraint linking them
  -- to the runtime tensor.
  asNamedTensor v = fromUnnamed . UnsafeMkTensor . D.asTensor $ toNestedList v

  -- Start from the 'Default' value 'def' and hand the decoded elements to
  -- 'replaceValues' along the same traversal.
  -- NOTE(review): behaviour when the list length does not match the number
  -- of targeted values depends on 'replaceValues' -- confirm in Torch.Lens.
  fromNestedList v = replaceValues (types @a) def (fmap fromNestedList v)

  -- Read the nested list out of the untyped tensor, then rebuild the value.
  fromNamedTensor = fromNestedList . D.asValue . toDynamic