{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Typed.Lens where

import Control.Applicative (liftA2)
import Control.Monad.State.Strict
import Data.Kind
import Data.Maybe (fromJust)
import Data.Proxy
import Data.Reflection hiding (D)
import Data.Type.Bool
import Data.Vector.Sized (Vector)
import GHC.Generics
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D hiding (select)
import qualified Torch.Functional.Internal as I
import qualified Torch.Internal.Managed.Type.TensorIndex as ATen
import Torch.Lens (Lens, Lens', Traversal, Traversal')
import qualified Torch.Tensor as T
import Torch.Typed.Auxiliary hiding (If)
import Torch.Typed.Tensor

-- | Shapes that contain the dimension tagged @name@.
class HasName (name :: Type -> Type) shape where
  -- | Traverse the slices of a 'NamedTensor' along the dimension tagged @name@.
  --
  -- The tensor is unbound along that dimension ('I.unbind'). Each slice (whose
  -- shape no longer contains @name@) is handed to the continuation in the order
  -- 'I.unbind' returns it, with effects sequenced in that same order. The
  -- results are then stacked back along the same dimension ('D.stack').
  --
  -- NOTE(review): if the dimension has size 0, 'D.stack' receives an empty
  -- list; confirm how "Torch.Functional" handles that case.
  name :: Traversal' (NamedTensor device dtype shape) (NamedTensor device dtype (DropName name shape))
  default name :: (KnownNat (NamedIdx name shape)) => Traversal' (NamedTensor device dtype shape) (NamedTensor device dtype (DropName name shape))
  -- 'traverse' runs the continuation's effects first-to-last, which keeps the
  -- traversal lawful (e.g. folds via 'Const' see slices in index order).
  name func s = restack <$> traverse func slices
    where
      -- Position of @name@ in @shape@, computed at the type level.
      dimension :: Int
      dimension = natValI @(NamedIdx name shape)

      -- The tensor split into its slices along 'dimension'.
      slices :: [NamedTensor device dtype (DropName name shape)]
      slices = map (fromUnnamed . UnsafeMkTensor) $ I.unbind (toDynamic s) dimension

      -- Inverse of 'slices': stack the (possibly updated) slices back together.
      restack :: [NamedTensor device dtype (DropName name shape)] -> NamedTensor device dtype shape
      restack vs = fromUnnamed . UnsafeMkTensor $ D.stack (D.Dim dimension) (map toDynamic vs)

-- | Every shape in which @name@ has a statically known position supports 'name'.
instance (KnownNat (NamedIdx name shape)) => HasName name shape

-- | Shapes whose dimensions can be indexed by a record field name.
class HasField (field :: Symbol) shape where
  -- | Lens onto the sub-tensor selected by the record field @field@.
  --
  -- Each dimension of @shape@ is indexed by its entry from 'fieldIdx'. With the
  -- @[Maybe Int]@ 'T.TensorIndex' instance of this module, 'Nothing' keeps the
  -- whole dimension and @'Just' i@ picks position @i@. Setting writes the new
  -- sub-tensor back at the same index with 'T.maskedFill'.
  field :: Lens' (NamedTensor device dtype shape) (NamedTensor device dtype (DropField field shape))
  default field :: (FieldIdx field shape) => Lens' (NamedTensor device dtype shape) (NamedTensor device dtype (DropField field shape))
  field func s = rebuild <$> func focus
    where
      -- Untyped view of the whole tensor, shared by getter and setter.
      raw = toDynamic s

      -- One optional index per dimension of the shape.
      idx :: [Maybe Int]
      idx = fieldIdx @field @shape Proxy

      -- Getter: the sub-tensor at 'idx'.
      focus :: NamedTensor device dtype (DropField field shape)
      focus = fromUnnamed . UnsafeMkTensor $ raw T.! idx

      -- Setter: write a replacement sub-tensor back at 'idx'.
      rebuild :: NamedTensor device dtype (DropField field shape) -> NamedTensor device dtype shape
      rebuild v = fromUnnamed . UnsafeMkTensor $ T.maskedFill raw idx (toDynamic v)

-- | Any shape for which field positions can be computed supports 'field'.
instance {-# OVERLAPS #-} FieldIdx field shape => HasField field shape

-- | Type-level test: does the 'Generic' representation @f@ contain a record
-- selector named @field@?
--
-- Metadata wrappers and sums/products are searched structurally. Leaves ('K1',
-- 'U1') and 'Vector' never match. Any other type constructor is applied to @()@
-- and its 'Rep' is searched instead.
type family GHasField (field :: Symbol) f :: Bool where
  GHasField label (S1 ('MetaSel ('Just label) _ _ _) _) = 'True
  GHasField _ (S1 ('MetaSel _ _ _ _) _) = 'False
  GHasField label (D1 _ inner) = GHasField label inner
  GHasField label (C1 _ inner) = GHasField label inner
  GHasField label (lhs :*: rhs) = GHasField label lhs || GHasField label rhs
  GHasField label (lhs :+: rhs) = GHasField label lhs || GHasField label rhs
  GHasField _ (K1 _ _) = 'False
  GHasField _ U1 = 'False
  GHasField _ (Vector _) = 'False
  GHasField label ty = GHasField label (Rep (ty ()))

-- | Remove the first dimension whose type has a record selector named @field@
-- (as decided by 'GHasField'). The shape is returned unchanged if none has.
type family DropField (field :: Symbol) (a :: [Type -> Type]) :: [Type -> Type] where
  DropField _ '[] = '[]
  DropField label (dim ': dims) = If (GHasField label dim) dims (dim ': DropField label dims)

-- | Remove the first occurrence of @name@ from a shape. The shape is returned
-- unchanged when @name@ does not occur.
type family DropName (name :: Type -> Type) (a :: [Type -> Type]) :: [Type -> Type] where
  DropName _ '[] = '[]
  DropName tag (tag ': rest) = rest
  DropName tag (other ': rest) = other ': DropName tag rest

-- | Index a tensor with one optional position per dimension.
--
-- 'Nothing' becomes the slice @0:maxBound:1@ (the whole dimension), and
-- @'Just' i@ becomes the integer index @i@. The new raw indices are placed in
-- front of the already accumulated ones.
--
-- NOTE(review): 'unsafePerformIO' is used because 'pushIndex' is pure; the IO
-- here only allocates fresh index objects through the ATen bindings. Confirm
-- those allocations have no other observable side effects.
instance {-# OVERLAPS #-} T.TensorIndex [Maybe Int] where
  pushIndex acc indices = unsafePerformIO $ do
      raws <- traverse toRaw indices
      pure (raws ++ acc)
    where
      toRaw Nothing = T.RawTensorIndex <$> ATen.newTensorIndexWithSlice 0 maxBound 1
      toRaw (Just i) = T.RawTensorIndex <$> ATen.newTensorIndexWithInt (fromIntegral i)

-- | Zero-based position of the first dimension tagged @name@ in @shape@.
--
-- Type checking fails with a custom error that names the missing tag when
-- @name@ does not occur in @shape@.
type family NamedIdx (name :: Type -> Type) (shape :: [Type -> Type]) :: Nat where
  NamedIdx name '[] =
    TypeError ('Text "The name " ':<>: 'ShowType name ':<>: 'Text " does not occur in the shape.")
  NamedIdx name (name ': xs) = 0
  NamedIdx name (x ': xs) = NamedIdx name xs + 1

-- | Look a record field up in every dimension of a shape.
class FieldIdx (field :: Symbol) (shape :: [Type -> Type]) where
  -- | One entry per dimension, in shape order: the value 'fieldId' reports for
  -- that dimension's type (its position of @field@, or 'Nothing').
  fieldIdx :: Proxy shape -> [Maybe Int]

-- | The empty shape has no dimensions and therefore no entries.
instance FieldIdx field '[] where
  fieldIdx = const []

-- | Query the head dimension (applied to @()@ so it has kind 'Type'), then
-- continue with the remaining dimensions.
instance (FieldId field (x ()), FieldIdx field xs) => FieldIdx field (x ': xs) where
  fieldIdx _ = fieldId @field (Proxy @(x ())) : fieldIdx @field (Proxy @xs)

-- | Position of the record field @field@ inside type @a@, if @a@ has one.
class FieldId (field :: Symbol) a where
  fieldId :: Proxy a -> Maybe Int
  -- By default the position is read off the 'Generic' representation.
  default fieldId :: (Generic a, GFieldId field (Rep a)) => Proxy a -> Maybe Int
  fieldId = const (gfieldId @field (Proxy @(Rep a)))

-- | Sized vectors never contain a named field.
instance FieldId field (Vector n v) where
  fieldId = const Nothing

-- | Fallback for every other type: use the 'Generic'-based default of 'fieldId'.
instance {-# OVERLAPS #-} (Generic s, GFieldId field (Rep s)) => FieldId field s

-- | Generic search for a record selector inside a 'Rep'.
class GFieldId (field :: Symbol) (a :: Type -> Type) where
  -- | Position of @field@ in the representation, if present.
  gfieldId :: Proxy a -> Maybe Int
  gfieldId = fst . gfieldId' @field @a

  -- | Worker: the position of @field@ (if found) paired with the number of
  -- slots this representation spans. Products use the count to offset
  -- positions found in their right operand.
  gfieldId' :: Proxy a -> (Maybe Int, Int)

-- | Datatype metadata is transparent: search the wrapped representation.
instance (GFieldId field f) => GFieldId field (M1 D t f) where
  gfieldId' = const (gfieldId' @field @f Proxy)

-- | Constructor metadata is transparent: search the wrapped representation.
instance (GFieldId field f) => GFieldId field (M1 C t f) where
  gfieldId' = const (gfieldId' @field @f Proxy)

-- | A named selector occupies one slot. It is the slot being searched for
-- exactly when its name equals @field@.
instance (KnownSymbol field, KnownSymbol field_) => GFieldId field (S1 ('MetaSel ('Just field_) p f b) (Rec0 a)) where
  gfieldId' _
    | symbolVal (Proxy :: Proxy field) == symbolVal (Proxy :: Proxy field_) = (Just 0, 1)
    | otherwise = (Nothing, 1)

-- | A positional (unnamed) selector can never match a field name, but it still
-- occupies one slot. Without this instance, any shape dimension whose
-- constructor has positional fields would have no 'GFieldId' instance. That
-- would be inconsistent with 'GHasField', which treats such selectors as
-- non-matching.
instance GFieldId field (S1 ('MetaSel 'Nothing p f b) (Rec0 a)) where
  gfieldId' _ = (Nothing, 1)

-- | A bare constant leaf occupies one slot and never matches.
instance GFieldId field (K1 c f) where
  gfieldId' = const (Nothing, 1)

-- | A nullary constructor is counted as one slot and never matches.
instance GFieldId field U1 where
  gfieldId' = const (Nothing, 1)

-- | Products: prefer a match in the left operand. A match in the right operand
-- is shifted by the left operand's slot count. The product spans the slots of
-- both operands.
instance (GFieldId field f, GFieldId field g) => GFieldId field (f :*: g) where
  gfieldId' _ =
    let (leftHit, leftSize) = gfieldId' @field @f Proxy
        (rightHit, rightSize) = gfieldId' @field @g Proxy
        hit = case leftHit of
          Just i -> Just i
          Nothing -> (+ leftSize) <$> rightHit
     in (hit, leftSize + rightSize)

-- | Sums: use the left alternative's result if it found the field, otherwise
-- return the right alternative's result unchanged (position and slot count).
instance (GFieldId field f, GFieldId field g) => GFieldId field (f :+: g) where
  gfieldId' _ = case gfieldId' @field @f Proxy of
    hit@(Just _, _) -> hit
    (Nothing, _) -> gfieldId' @field @g Proxy