{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Lens where

import Control.Monad.Identity
import Control.Monad.State.Strict
import GHC.Generics

-- | A van Laarhoven lens: given a way to turn the focused @a@ into an
-- @f b@, it turns an @s@ into an @f t@. Only 'Functor' is required of @f@.
type Lens s t a b = forall f. Functor f => (a -> f b) -> s -> f t

-- | A type-preserving 'Lens'; equivalent to @'Lens' s s a a@.
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

-- | A van Laarhoven traversal: like 'Lens', but may visit any number of
-- @a@ targets, so it needs 'Applicative' to combine the visits' effects.
type Traversal s t a b = forall f. Applicative f => (a -> f b) -> s -> f t

-- | A type-preserving 'Traversal'; equivalent to @'Traversal' s s a a@.
type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s

-- | @HasTypes s a@: the values of type @a@ held inside an @s@ can be
-- visited (and replaced) with the traversal 'types_'.
--
-- Instances that give no explicit 'types_' get a "GHC.Generics"-based
-- one: the value is converted to its generic representation with 'from',
-- walked with 'gtypes', and rebuilt with 'to'.
class HasTypes s a where
  -- | Traverse the @a@ values contained in an @s@.
  types_ :: Traversal' s a
  default types_ :: (Generic s, GHasTypes (Rep s) a) => Traversal' s a
  types_ visit = fmap to . gtypes visit . from
  {-# INLINE types_ #-}

-- | Fallback instance: any type @s@ gets 'types_' from the class's default,
-- "GHC.Generics"-based implementation (no method body is given here).
--
-- Marked OVERLAPS so that more specific instances (e.g. the list, pair and
-- triple instances in this module) are preferred wherever their heads match.
-- GHC picks an instance by its head alone, so for any type without a more
-- specific instance this one is selected, and its 'Generic' /
-- 'GHasTypes' context must then be satisfiable.
instance {-# OVERLAPS #-} (Generic s, GHasTypes (Rep s) a) => HasTypes s a

-- | Apply a pure function to every target of a traversal. Lenses are
-- accepted too, since they only require 'Functor'.
--
-- The traversal is run in the 'Identity' functor, so no effects are
-- involved and the rebuilt structure is returned directly.
--
-- The signature admits type-changing traversals (@t@ / @b@ may differ from
-- @s@ / @a@). A @'Traversal'' s a@ is just the special case @t ~ s@,
-- @b ~ a@, so existing callers are unaffected. @s@ and @a@ are quantified
-- first so that explicit type applications @over \@s \@a@ keep working.
over :: forall s a t b. Traversal s t a b -> (a -> b) -> s -> t
over l f = runIdentity . l (Identity . f)

-- | Collect every target of a traversal into a list, in visiting order.
--
-- The traversal runs in 'State'. Each visited value is left unchanged in
-- the structure and prepended to an accumulator. The accumulator is
-- reversed at the end so the list follows the traversal order.
flattenValues :: forall a s. Traversal' s a -> s -> [a]
flattenValues func orgData = reverse (execState (func record orgData) [])
  where
    -- Remember the visited value and hand it back untouched.
    record :: a -> State [a] a
    record v = v <$ modify (v :)

-- | Overwrite the targets of a traversal, in visiting order, with the
-- values from a list.
--
-- Each visited target is replaced by the next element of @newValues@.
-- Any elements left over after the traversal finishes are discarded.
-- If the list runs out before the traversal does, this calls 'error'.
replaceValues :: forall a s. Traversal' s a -> s -> [a] -> s
replaceValues func orgData newValues = evalState (func takeNext orgData) newValues
  where
    -- Ignore the old target; yield the head of the remaining supply.
    takeNext :: a -> State [a] a
    takeNext _ = state $ \remaining -> case remaining of
      [] -> error "Not enough values supplied to replaceValues"
      next : rest -> (next, rest)

-- | 'types_' with the quantified type variables reordered so that the target
-- type @a@ comes first. A single type application (@types \@T@) therefore
-- chooses what to traverse, while the container @s@ is left to inference.
types :: forall a s. HasTypes s a => Traversal' s a
types visit = types_ @s @a visit

-- | Generic-representation counterpart of 'HasTypes'. Here @s@ is a
-- "GHC.Generics" representation type constructor (e.g. @'Rep' x@), and
-- 'gtypes' visits the @a@ values stored in its fields.
class GHasTypes s a where
  -- | Traverse the @a@ values inside a representation, whatever its
  -- parameter type is.
  gtypes :: forall p. Traversal' (s p) a

-- | Constructors without fields: there is nothing to visit, so the value is
-- returned unchanged inside the applicative.
instance GHasTypes U1 a where
  gtypes _ unit = pure unit
  {-# INLINE gtypes #-}

-- | Sums: traverse whichever alternative the value was built with, and rewrap
-- the rebuilt payload in the same injection ('L1' or 'R1').
--
-- Carries the same INLINE pragma as the other 'GHasTypes' instances. It was
-- previously missing here, which could stop the generic traversal from
-- being fully inlined for sum types.
instance (GHasTypes f a, GHasTypes g a) => GHasTypes (f :+: g) a where
  gtypes func (L1 x) = L1 <$> gtypes func x
  gtypes func (R1 x) = R1 <$> gtypes func x
  {-# INLINE gtypes #-}

-- | Products: traverse the left factor, then the right one (effects are
-- sequenced in that order by '<*>'), and pair the results back up with ':*:'.
instance (GHasTypes f a, GHasTypes g a) => GHasTypes (f :*: g) a where
  gtypes visit (left :*: right) =
    fmap (:*:) (gtypes visit left) <*> gtypes visit right
  {-# INLINE gtypes #-}

-- | A single field: hand its value to 'types', so the field type's own
-- 'HasTypes' instance decides what to visit, then rewrap it in 'K1'.
instance (HasTypes s a) => GHasTypes (K1 i s) a where
  gtypes visit (K1 field) = fmap K1 (types visit field)
  {-# INLINE gtypes #-}

-- | Metadata wrappers ('M1' for datatype / constructor / selector info)
-- hold no data of their own: unwrap, traverse the contents, and rewrap.
instance GHasTypes s a => GHasTypes (M1 i t s) a where
  gtypes visit = fmap M1 . gtypes visit . unM1
  {-# INLINE gtypes #-}

-- | Lists: visit every element from head to tail, using the element type's
-- own 'HasTypes' instance, and rebuild the list with the same length.
instance {-# OVERLAPS #-} (HasTypes s a) => HasTypes [s] a where
  types_ visit = traverse (types_ visit)
  {-# INLINE types_ #-}

-- | Pairs: visit the first component, then the second, each with its own
-- 'HasTypes' instance, and rebuild the pair.
instance {-# OVERLAPS #-} (HasTypes s0 a, HasTypes s1 a) => HasTypes (s0, s1) a where
  types_ visit (lhs, rhs) =
    fmap (,) (types_ visit lhs) <*> types_ visit rhs
  {-# INLINE types_ #-}

-- | Triples: visit the components left to right, each with its own
-- 'HasTypes' instance, and rebuild the triple.
instance {-# OVERLAPS #-} (HasTypes s0 a, HasTypes s1 a, HasTypes s2 a) => HasTypes (s0, s1, s2) a where
  types_ visit (p, q, r) =
    fmap (,,) (types_ visit p)
      <*> types_ visit q
      <*> types_ visit r
  {-# INLINE types_ #-}