{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

module Torch.Index
  ( slice,
    lslice,
  )
where

import Control.Monad ((>=>))
import Data.Void
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Quote (QuasiQuoter (..))
import Language.Haskell.TH.Syntax hiding (Unsafe)
import Text.Megaparsec as M
import Text.Megaparsec.Char hiding (space)
import Text.Megaparsec.Char.Lexer
import Torch.Tensor

-- | Parser over 'String' input with no custom error component ('Void').
type Parser = Parsec Void String

-- | Space consumer used between tokens.
--
-- Skips any run of whitespace characters. Both the line-comment and the
-- block-comment parsers are 'empty', so no comment syntax is recognised.
sc :: Parser ()
sc = space space1 noComment noComment
  where
    -- A parser that always fails: comments are not part of the grammar.
    noComment :: Parser ()
    noComment = empty

-- | Run the given parser, then skip any whitespace that follows it (using 'sc').
lexm :: Parser a -> Parser a
lexm p = lexeme sc p

-- | Parse the body of a @[slice| ... |]@ / @[lslice| ... |]@ quasi-quote into
-- one Template Haskell expression per comma-separated index component.
--
-- Each component is one of:
--
-- * a slice @start:stop:step@ in which every part is optional. It becomes the
--   'Slice' constructor applied to @()@ (nothing given), a pair
--   @(start, stop)@, or a triple @(start, stop, step)@, with the 'None'
--   constructor standing in for an omitted bound;
-- * @True@ or @False@;
-- * @None@, @Ellipsis@ or @...@;
-- * a decimal integer, optionally negated with @-@, or @{name}@, which becomes
--   a reference to the variable @name@.
--
-- The whole input must be consumed: anything left over after the last
-- component is a parse error rather than being silently ignored. Parse
-- failures are reported through 'fail' in 'Q', rendered with
-- 'M.errorBundlePretty'.
parseSlice :: String -> Q [Exp]
parseSlice str =
  case M.runParser components "slice" str of
    Left bundle -> fail (M.errorBundlePretty bundle)
    Right exps -> pure exps
  where
    -- Comma-separated components; 'M.eof' rejects trailing garbage.
    components :: Parser [Exp]
    components = ((sc *> component) `sepBy` char ',') <* M.eof

    -- Slices are tried first because they may start with a number; 'try'
    -- rewinds the input so that a plain number is still accepted afterwards.
    component :: Parser Exp
    component = try sliceP <|> try boolP <|> try otherP <|> numberP

    -- A fixed word that maps to a constructor.
    keyword :: String -> Name -> Parser Exp
    keyword word con = ConE con <$ lexm (string word)

    otherP :: Parser Exp
    otherP =
      keyword "None" 'None
        <|> keyword "Ellipsis" 'Ellipsis
        <|> keyword "..." 'Ellipsis

    boolP :: Parser Exp
    boolP = keyword "True" 'True <|> keyword "False" 'False

    numberP :: Parser Exp
    numberP = literal <|> negated <|> variable

    literal :: Parser Exp
    literal = LitE . IntegerL <$> lexm decimal

    -- Whitespace is allowed between the sign and the digits.
    negated :: Parser Exp
    negated = LitE . IntegerL . negate <$> (lexm (string "-") *> lexm decimal)

    -- @{name}@: an alphanumeric identifier resolved at the splice site.
    variable :: Parser Exp
    variable =
      VarE . mkName
        <$> lexm (between (char '{') (char '}') (some alphaNumChar))

    colon :: Parser String
    colon = lexm (string ":")

    -- @start? ':' stop? (':' step?)?@ -- at least one colon is required.
    sliceP :: Parser Exp
    sliceP = do
      start <- M.optional numberP
      _ <- colon
      stop <- M.optional numberP
      step <- M.option Nothing (colon *> M.optional numberP)
      pure $ AppE (ConE 'Slice) (sliceArgs start stop step)

    -- Build the argument of 'Slice'. A missing step yields a pair (or unit
    -- when both bounds are missing too); a present step yields a triple.
    sliceArgs :: Maybe Exp -> Maybe Exp -> Maybe Exp -> Exp
    sliceArgs Nothing Nothing Nothing = ConE '()
    sliceArgs start stop Nothing = tupleOf [orNone start, orNone stop]
    sliceArgs start stop (Just step) = tupleOf [orNone start, orNone stop, step]

    orNone :: Maybe Exp -> Exp
    orNone (Just e) = e
    orNone Nothing = ConE 'None

    tupleOf :: [Exp] -> Exp
    tupleOf = TupE . map Just

-- | Generate a slice from a [python compatible expression](https://pytorch.org/cppdocs/notes/tensor_indexing.html).
-- Where you would take the odd-numbered elements of a tensor with `tensor[1::2]` in python,
-- you can write `tensor ! [slice|1::2|]` in hasktorch.
--
-- Only the expression quoter is supported; using the quoter as a pattern,
-- declaration or type fails with a descriptive message.
slice :: QuasiQuoter
slice =
  QuasiQuoter
    { quoteExp = parseSlice >=> pure . combine,
      quotePat = unsupported "quotePat",
      quoteDec = unsupported "quoteDec",
      quoteType = unsupported "quoteType"
    }
  where
    -- A single component is used as-is; any other number of components is
    -- wrapped in a tuple.
    combine :: [Exp] -> Exp
    combine [single] = single
    combine exps = TupE (map Just exps)

    -- Report an unsupported quoter context through 'fail' in 'Q'.
    unsupported :: String -> String -> Q a
    unsupported field _ = fail (field ++ " is not implemented for slice.")

-- | Generate a lens from a [python compatible expression](https://pytorch.org/cppdocs/notes/tensor_indexing.html).
-- Where you would take the odd-numbered elements of a tensor with `tensor[1::2]` in python,
-- you can write `tensor ^. [lslice|1::2|]` in hasktorch.
-- To put 2 in the odd-numbered elements of the tensor,
-- you can write `tensor & [lslice|1::2|] .~ 2`.
--
-- Only the expression quoter is supported; using the quoter as a pattern,
-- declaration or type fails with a descriptive message.
lslice :: QuasiQuoter
lslice =
  QuasiQuoter
    { quoteExp = parseSlice >=> pure . AppE (VarE 'toLens) . combine,
      quotePat = unsupported "quotePat",
      quoteDec = unsupported "quoteDec",
      quoteType = unsupported "quoteType"
    }
  where
    -- A single component is used as-is; any other number of components is
    -- wrapped in a tuple. The result is then passed to 'toLens'.
    combine :: [Exp] -> Exp
    combine [single] = single
    combine exps = TupE (map Just exps)

    -- Report an unsupported quoter context through 'fail' in 'Q'.
    unsupported :: String -> String -> Q a
    unsupported field _ = fail (field ++ " is not implemented for lslice.")