{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

module Torch.GraduallyTyped.Tensor.Indexing
  ( IndexType (..),
    SIndexType (..),
    Indices (..),
    SIndices (..),
    IndexDims,
    (!),
    slice,
  )
where

import Control.Arrow ((>>>))
import Control.Monad (forM_, void, (<=<))
import Control.Monad.Catch (MonadThrow)
import Control.Monad.Trans (lift)
import Data.Coerce (coerce)
import Data.Foldable (asum)
import Data.Kind (Type)
import Data.Singletons (Demote, SingI, SingKind, SomeSing (..), fromSing, sing, toSing, withSomeSing)
import Data.Singletons.TH (genSingletons)
import Data.Type.Equality (type (==))
import Data.Void (Void)
import Foreign (fromBool)
import GHC.TypeLits (Div, ErrorMessage (..), Nat, Symbol, type (+), type (-), type (<=?))
import GHC.Natural (Natural)
import qualified Language.Haskell.TH as TH
import Language.Haskell.TH.Quote (QuasiQuoter (..))
import Text.Megaparsec as M
import qualified Text.Megaparsec.Char as M
import qualified Text.Megaparsec.Char.Lexer as L
import Torch.GraduallyTyped.Index.Type (DemotedIndex (..), Index (..), SIndex (..))
import Torch.GraduallyTyped.Prelude (If, IsChecked (..), forgetIsChecked, type (<?))
import Torch.GraduallyTyped.Prelude.Bool (SBool (..))
import Torch.GraduallyTyped.Prelude.List (Reverse, SList (..), Sing)
import Torch.GraduallyTyped.Shape.Class (PrependDimF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor (..))
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Type.TensorIndex as ATen
import Type.Errors.Pretty (TypeError, type (%), type (<>))

-- | One index of a tensor indexing expression. Each constructor lists the
-- Python indexing form it corresponds to. The parameter @a@ is the type of
-- the positions and steps it carries: type-level 'Index' values when used in
-- 'IndexDims', and 'Integer' once demoted.
data IndexType a
  = -- | Insert a new dimension of size 1 (@None@ in Python).
    NewAxis
  | -- | @...@
    Ellipsis
  | -- | @:@, keep the whole dimension.
    SliceAll
  | -- | @at@, select a single position and drop the dimension.
    SliceAt a
  | -- | A scalar boolean index (@True@ or @False@).
    SliceBool Bool
  | -- | @from:@
    SliceFrom a
  | -- | @:upTo@
    SliceUpTo a
  | -- | @::step@
    SliceWithStep a
  | -- | @from:upTo@
    SliceFromUpTo a a
  | -- | @from::step@
    SliceFromWithStep a a
  | -- | @:upTo:step@
    SliceUpToWithStep a a
  | -- | @from:upTo:step@
    SliceFromUpToWithStep a a a
  deriving (Show, Eq, Functor)

-- Derive singletons for 'IndexType': the singleton type 'SIndexType' with
-- constructors such as 'SNewAxis', 'SSliceAt' and 'SSliceBool', which the
-- @slice@ quasi-quoter in this module refers to.
$(genSingletons [''IndexType])

-- | Reverse the dimension list of a checked shape; an unchecked shape stays
-- unchecked.
type family ReverseShape (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  ReverseShape ('Shape dims) = 'Shape (Reverse dims)
  ReverseShape 'UncheckedShape = 'UncheckedShape

-- | Return the list of indices unchanged, but raise a custom type error if it
-- contains an 'Ellipsis'.
--
-- 'IndexDimsImpl' applies this to the indices that follow an ellipsis, so that
-- at most one ellipsis can appear. The equations are matched top to bottom:
-- the catch-all third equation only fires once the 'Ellipsis' equation cannot
-- match.
type ErrorOnEllipsis :: [IndexType (Index Nat)] -> [IndexType (Index Nat)]
type family ErrorOnEllipsis indices where
  ErrorOnEllipsis '[] = '[]
  ErrorOnEllipsis ('Ellipsis ': ixs) = TypeError ('Text "Indices can only contain a single ellipsis ('...').")
  ErrorOnEllipsis (ix ': ixs) = ix ': ErrorOnEllipsis ixs

-- | Custom type error message reported whenever a slice step is zero.
type StepZeroErrorMessage :: ErrorMessage
type StepZeroErrorMessage = 'Text "Slice step cannot be zero"

-- | Number of elements a slice with step @step@ selects from a dimension of
-- size @len@, i.e. the ceiling of @len / step@. A zero step is a type error.
--
-- >>> :kind! Stepped 8 1
-- Stepped 8 1 :: Natural
-- = 8
-- >>> :kind! Stepped 5 2
-- Stepped 5 2 :: Natural
-- = 3
-- >>> :kind! Stepped 6 3
-- Stepped 6 3 :: Natural
-- = 2
type Stepped :: Nat -> Nat -> Nat
type family Stepped len step where
  -- Checked first, so that a zero step is rejected even for empty dimensions.
  Stepped _ 0 = TypeError StepZeroErrorMessage
  Stepped 0 _ = 0
  Stepped len step = Div (len - 1) step + 1

-- | Return @ok@, unless the slice end @upTo@ is zero. In that case a custom
-- type error is raised. The branch is selected with 'If' on @upTo == 0@.
type family CheckUpTo (upTo :: Nat) ok where
  CheckUpTo upTo ok =
    If
      (upTo == 0)
      (TypeError ('Text "Slice 'upTo' type parameter must not be equal to zero"))
      ok

-- | Return @ok@ if the slice start @from@ is strictly smaller than the size of
-- the indexed dimension. Otherwise raise a custom type error that shows both
-- values.
type family CheckFromSize (from :: Nat) (size :: Nat) ok where
  CheckFromSize from size ok =
    If
      (from <? size)
      ok
      ( TypeError
          ( "Slice 'from' type parameter must be smaller than the size of the indexed dimension:"
              % "    " <> "from < size"
              % "but"
              % "    " <> from <> " >= " <> size
          )
      )

-- | Validate a slice end @upTo@ against the size of the indexed dimension.
--
-- The check depends on whether the size is known:
--
-- * unchecked size: only the @upTo /= 0@ check of 'CheckUpTo' applies;
-- * known size: @upTo <= size@ is required as well, and a custom type error
--   is raised otherwise.
type family CheckUpToSize (upTo :: Nat) (size :: Size Nat) ok where
  CheckUpToSize upTo 'UncheckedSize ok = CheckUpTo upTo ok
  CheckUpToSize upTo ('Size size) ok =
    CheckUpTo
      upTo
      ( If
          (upTo <=? size)
          ok
          ( TypeError
              ( "Slice 'upTo' type parameter must be less than or equal to the size of the indexed dimension:"
                  % "    " <> "upTo <= size"
                  % "but"
                  % "    " <> upTo <> " > " <> size
              )
          )
      )

-- | Return @ok@ if the slice start @from@ is strictly smaller than the slice
-- end @upTo@. Otherwise raise a custom type error that shows both values.
type family CheckFromUpTo (from :: Nat) (upTo :: Nat) ok where
  CheckFromUpTo from upTo ok =
    If
      (from <? upTo)
      ok
      ( TypeError
          ( "Slice 'from' type parameter must be less than the 'upTo' type parameter:"
              % "    " <> "from < upTo"
              % "but"
              % "    " <> from <> " >= " <> upTo
          )
      )

-- | Combined check for a @from:upTo@ slice: @from < upTo@ ('CheckFromUpTo'),
-- followed by the 'CheckUpToSize' checks of @upTo@ against the dimension
-- size.
type family CheckFromUpToSize (from :: Nat) (upTo :: Nat) (size :: Size Nat) ok where
  CheckFromUpToSize from upTo size ok = CheckFromUpTo from upTo (CheckUpToSize upTo size ok)

-- | Return @ok@ if the 'SliceAt' position @at@ is strictly smaller than the
-- size of the indexed dimension. Otherwise raise a custom type error that
-- shows both values.
type family CheckSliceAt (at :: Nat) (size :: Nat) ok where
  CheckSliceAt at size ok =
    If
      (at <? size)
      ok
      ( TypeError
          ( "Index of 'SliceAt' must be less than the size of the indexed dimension:"
              % "    " <> "at < size"
              % "but"
              % "    " <> at <> " >= " <> size
          )
      )

-- | Reject a step that is the type-level index 0 with 'StepZeroErrorMessage'.
-- Any other step, including one that is not a concrete 'Index', is accepted
-- and @ok@ is returned.
type family CheckStep (step :: Index Nat) ok where
  CheckStep ('Index 0) _ = TypeError StepZeroErrorMessage
  CheckStep _ ok = ok

-- | Compute the shape that results from indexing the dimensions @dims@ with
-- the type-level @indices@.
--
-- * Indices are matched left to right against @dims@. Dimensions left over
--   once the indices run out are kept as they are.
-- * 'NewAxis' and 'SliceBool' do not consume a dimension. Each inserts a new
--   dimension named @"*"@: of size 1 for 'NewAxis' and @'SliceBool' 'True@,
--   and of size 0 for @'SliceBool' 'False@. This follows PyTorch, where a
--   scalar boolean index unsqueezes the tensor at its position.
--   NOTE(review): PyTorch merges several scalar booleans into one dimension;
--   this family is only exact for a single boolean index.
-- * 'Ellipsis' processes the indices after it right to left against the
--   reversed dimensions, so that they line up with the trailing dimensions.
--   'ErrorOnEllipsis' rejects a second ellipsis. An empty @dims@ goes through
--   the same path, so indices after the ellipsis that add dimensions (e.g.
--   'NewAxis') are still accounted for.
-- * 'SliceAt' removes the dimension; 'SliceAll' keeps it unchanged. The other
--   slices keep the dimension and compute its new size when the relevant
--   bounds are 'Index' values: from @upTo@ (and @from@) when an end bound is
--   given, or from the checked 'Size' otherwise. In all remaining cases the
--   size becomes 'UncheckedSize'. Bounds and steps are validated with the
--   @Check*@ families.
--   NOTE(review): with an unchecked dimension size, the @upTo@ forms still
--   produce a checked size without verifying @upTo <= size@.
type IndexDimsImpl :: [IndexType (Index Nat)] -> [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type family IndexDimsImpl indices dims where
  IndexDimsImpl '[] dims = 'Shape dims
  IndexDimsImpl ('NewAxis ': ixs) dims = 'Dim ('Name "*") ('Size 1) `PrependDimF` IndexDimsImpl ixs dims
  -- Scalar booleans insert a new dimension instead of consuming one.
  IndexDimsImpl ('SliceBool 'False ': ixs) dims = 'Dim ('Name "*") ('Size 0) `PrependDimF` IndexDimsImpl ixs dims
  IndexDimsImpl ('SliceBool 'True ': ixs) dims = 'Dim ('Name "*") ('Size 1) `PrependDimF` IndexDimsImpl ixs dims
  -- No special case for an empty 'dims': the indices after the ellipsis may
  -- still add dimensions, and must still be checked for a second ellipsis.
  IndexDimsImpl ('Ellipsis ': ixs) dims = ReverseShape (IndexDimsImpl (Reverse (ErrorOnEllipsis ixs)) (Reverse dims))
  IndexDimsImpl ('SliceAll ': ixs) (dim ': dims) = dim `PrependDimF` IndexDimsImpl ixs dims
  IndexDimsImpl ('SliceAt ('Index at) ': ixs) ('Dim name ('Size size) ': dims) = CheckSliceAt at size (IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceAt _ ': ixs) ('Dim name _ ': dims) = IndexDimsImpl ixs dims
  IndexDimsImpl ('SliceFrom ('Index from) ': ixs) ('Dim name ('Size size) ': dims) =
    CheckFromSize from size ('Dim name ('Size (size - from)) `PrependDimF` IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceFrom _ ': ixs) ('Dim name _ ': dims) =
    'Dim name 'UncheckedSize `PrependDimF` IndexDimsImpl ixs dims
  IndexDimsImpl ('SliceUpTo ('Index upTo) ': ixs) ('Dim name size ': dims) =
    CheckUpToSize upTo size ('Dim name ('Size upTo) `PrependDimF` IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceUpTo _ ': ixs) ('Dim name _ ': dims) =
    'Dim name 'UncheckedSize `PrependDimF` IndexDimsImpl ixs dims
  IndexDimsImpl ('SliceWithStep ('Index step) ': ixs) ('Dim name ('Size size) ': dims) =
    'Dim name ('Size (Stepped size step)) `PrependDimF` IndexDimsImpl ixs dims
  IndexDimsImpl ('SliceWithStep step ': ixs) ('Dim name _ ': dims) =
    CheckStep step ('Dim name 'UncheckedSize `PrependDimF` IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceFromUpTo ('Index from) ('Index upTo) ': ixs) ('Dim name size ': dims) =
    CheckFromUpToSize from upTo size ('Dim name ('Size (upTo - from)) `PrependDimF` IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceFromUpTo _ _ ': ixs) ('Dim name _ ': dims) =
    'Dim name 'UncheckedSize `PrependDimF` IndexDimsImpl ixs dims
  IndexDimsImpl ('SliceFromWithStep ('Index from) ('Index step) ': ixs) ('Dim name ('Size size) ': dims) =
    CheckFromSize from size ('Dim name ('Size (Stepped (size - from) step)) `PrependDimF` IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceFromWithStep _ step ': ixs) ('Dim name _ ': dims) =
    CheckStep step ('Dim name 'UncheckedSize `PrependDimF` IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceUpToWithStep ('Index upTo) ('Index step) ': ixs) ('Dim name size ': dims) =
    CheckUpToSize upTo size ('Dim name ('Size (Stepped upTo step)) `PrependDimF` IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceUpToWithStep _ step ': ixs) ('Dim name _ ': dims) =
    CheckStep step ('Dim name 'UncheckedSize `PrependDimF` IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceFromUpToWithStep ('Index from) ('Index upTo) ('Index step) ': ixs) ('Dim name size ': dims) =
    CheckFromUpToSize from upTo size ('Dim name ('Size (Stepped (upTo - from) step)) `PrependDimF` IndexDimsImpl ixs dims)
  IndexDimsImpl ('SliceFromUpToWithStep _ _ step ': ixs) ('Dim name _ ': dims) =
    CheckStep step ('Dim name 'UncheckedSize `PrependDimF` IndexDimsImpl ixs dims)

-- | Shape of the result of indexing a tensor of shape @shape@ with @indices@.
-- If either the indices or the shape are unchecked, the result is
-- 'UncheckedShape'. Otherwise the dimensions are computed by 'IndexDimsImpl'.
type family IndexDims indices shape where
  IndexDims 'UncheckedIndices _ = 'UncheckedShape
  IndexDims _ 'UncheckedShape = 'UncheckedShape
  IndexDims ('Indices indices) ('Shape dims) = IndexDimsImpl indices dims

-- | A list of indices that is either unchecked or checked. Promoted,
-- @'UncheckedIndices@ stands for indices known only at runtime, and
-- @'Indices xs@ carries the type-level list @xs@ (see 'SIndices').
data Indices (indexTypes :: Type)
  = UncheckedIndices
  | Indices indexTypes

-- | Singleton for the kind 'Indices'.
--
-- * 'SUncheckedIndices' holds the indices as plain runtime values.
-- * 'SIndices' holds a singleton list whose elements are also known at the
--   type level.
data SIndices (indices :: Indices [IndexType (Index Nat)]) where
  SUncheckedIndices :: [IndexType Integer] -> SIndices 'UncheckedIndices
  SIndices :: SList indexTypes -> SIndices ('Indices indexTypes)

-- | Register 'SIndices' as the singleton type for the kind 'Indices'.
type instance Sing = SIndices

-- | The singleton of a checked index list wraps the singleton of the list of
-- its elements.
instance SingI indexTypes => SingI ('Indices (indexTypes :: [IndexType (Index Nat)])) where
  sing = SIndices $ sing @indexTypes

-- | Demotion and promotion of index lists.
--
-- * 'SUncheckedIndices' demotes to 'Unchecked', with every element wrapped
--   in 'Unchecked'.
-- * 'SIndices' demotes to 'Checked'. Its elements come from demoting the
--   singleton list and 'coerce'-ing the element type to
--   @'IsChecked' 'Integer'@. NOTE(review): this relies on 'DemotedIndex'
--   being a newtype over 'Integer'; confirm in
--   "Torch.GraduallyTyped.Index.Type".
--
-- Promotion goes the other way. For 'Unchecked', the per-element 'IsChecked'
-- wrappers are dropped with 'forgetIsChecked'. For 'Checked', every integer is
-- wrapped in 'DemotedIndex' and lifted to a singleton list via 'withSomeSing'.
instance SingKind (Indices [IndexType (Index Nat)]) where
  type Demote (Indices [IndexType (Index Nat)]) = IsChecked [IndexType (IsChecked Integer)]
  fromSing (SUncheckedIndices indexTypes) = Unchecked $ fmap Unchecked <$> indexTypes
  fromSing (SIndices indexTypes) = Checked . coerce . fromSing $ indexTypes
  toSing (Unchecked indexTypes) = SomeSing . SUncheckedIndices $ fmap forgetIsChecked <$> indexTypes
  toSing (Checked indexTypes) =
    -- Three fmaps: over the list, over each 'IndexType', over each 'IsChecked'.
    withSomeSing ((fmap . fmap . fmap) DemotedIndex indexTypes) $ SomeSing . SIndices

-- | Index a tensor with a singleton list of indices, mirroring
-- @torch::Tensor::index@ of the PyTorch C++ frontend.
--
-- The result shape is computed at the type level by 'IndexDims'. At runtime:
--
-- 1. the indices are demoted with 'fromSing';
-- 2. each one is converted to an ATen @TensorIndex@;
-- 3. the @TensorIndex@ values are pushed onto an ATen index list;
-- 4. the list is passed to 'ATen.index'.
--
-- Open-ended slice bounds are encoded as 'maxBound' of the C integer type
-- used by the ATen constructors. Every index value is narrowed from
-- 'Integer' to that type. A value that does not fit raises an 'IOError'
-- instead of silently wrapping around. NOTE(review): this assumes
-- 'unsafeThrowableIO' rethrows IO exceptions in @m@; confirm.
(!) ::
  forall indices requiresGradient layout device dataType shape m.
  MonadThrow m =>
  Tensor requiresGradient layout device dataType shape ->
  SIndices indices ->
  m (Tensor requiresGradient layout device dataType (IndexDims indices shape))
(UnsafeTensor t) ! sIndices = unsafeThrowableIO $ do
  indexList <- ATen.newTensorIndexList
  tensorIndices <- traverse toTensorIndex indices
  forM_ tensorIndices $ ATen.tensorIndexList_push_back indexList
  UnsafeTensor <$> ATen.index t indexList
  where
    -- The demoted indices with all 'IsChecked' wrappers removed.
    indices :: [IndexType Integer]
    indices = fmap forgetIsChecked <$> forgetIsChecked (fromSing sIndices)

    -- Narrow the index to the C integer type, then build the ATen index.
    toTensorIndex = narrowIndex >>> (>>= mkTensorIndex)

    -- 'fromIntegral' wraps around silently on overflow. Converting back with
    -- 'toInteger' and comparing detects any value that did not survive the
    -- narrowing.
    narrowIndex index =
      let narrowed = fromIntegral <$> index
       in if fmap toInteger narrowed == index
            then pure narrowed
            else ioError . userError $ "Index " <> show index <> " is out of range for a C int."

    -- One ATen constructor per index form. Missing slice bounds become
    -- 0 (start), maxBound (end) and 1 (step).
    mkTensorIndex = \case
      NewAxis -> ATen.newTensorIndexWithNone
      Ellipsis -> ATen.newTensorIndexWithEllipsis
      SliceAt at -> ATen.newTensorIndexWithInt at
      SliceBool b -> ATen.newTensorIndexWithBool (fromBool b)
      SliceAll -> ATen.newTensorIndexWithSlice 0 maxBound 1
      SliceFrom from -> ATen.newTensorIndexWithSlice from maxBound 1
      SliceUpTo upTo -> ATen.newTensorIndexWithSlice 0 upTo 1
      SliceWithStep step -> ATen.newTensorIndexWithSlice 0 maxBound step
      SliceFromUpTo from upTo -> ATen.newTensorIndexWithSlice from upTo 1
      SliceFromWithStep from step -> ATen.newTensorIndexWithSlice from maxBound step
      SliceUpToWithStep upTo step -> ATen.newTensorIndexWithSlice 0 upTo step
      SliceFromUpToWithStep from upTo step -> ATen.newTensorIndexWithSlice from upTo step

-- | Parser used by the @slice@ quasi-quoter. It runs in 'TH.Q' so that
-- parsers can build Template Haskell expressions directly.
type Parser = ParsecT Void String TH.Q

-- | Skip any amount of whitespace. Both comment parsers are 'empty', so no
-- comment syntax is recognised.
sc :: Parser ()
sc = L.space M.space1 empty empty

-- | Run a parser, then skip the whitespace that follows it.
lexeme :: Parser a -> Parser a
lexeme = L.lexeme sc

-- | Parse the given character and skip trailing whitespace.
char :: Char -> Parser Char
char = lexeme . M.char

-- | Parse the given string and skip trailing whitespace.
string :: String -> Parser String
string = lexeme . M.string

-- | Parse the body of a @[slice|...|]@ quasi-quotation into an expression
-- @SIndices (SCons i1 (SCons i2 ... SNil))@.
--
-- The input is a comma-separated, possibly empty, list of indices. Each
-- index is one of:
--
-- * a Python-style slice @from:upTo:step@ with every part optional, e.g.
--   @:@, @1:@, @:3@, @1:3@, @::2@, @1::2@, @:3:2@, @1:3:2@, including
--   trailing-colon forms such as @1:3:@;
-- * a single index, giving 'SliceAt';
-- * @True@ or @False@, giving 'SliceBool';
-- * @+@, @NewAxis@ or @None@, giving 'NewAxis';
-- * @...@ or @Ellipsis@, giving 'Ellipsis'.
--
-- A single index is either an integer literal (possibly negative) or an
-- antiquoted variable @{name}@. Integer literals become type-level indices:
-- @SIndex \@n@ for @n >= 0@ and @SNegativeIndex \@|n|@ otherwise. @{name}@
-- splices the Haskell variable @name@.
--
-- Parse errors are reported with 'fail' at compile time.
parseSlice :: String -> TH.Q TH.Exp
parseSlice = either (fail . errorBundlePretty) pure <=< M.runParserT indicesP ""
  where
    indicesP :: Parser TH.Exp
    indicesP = do
      indexExps <- sc *> (try sliceP <|> try boolP <|> otherP) `sepBy` char ',' <* eof
      -- Build the singleton list SCons e1 (SCons e2 (... SNil)).
      let indicesExp = foldr (TH.AppE . TH.AppE (TH.ConE 'SCons)) (TH.ConE 'SNil) indexExps
      lift [|SIndices $(pure indicesExp)|]

    otherP :: Parser TH.Exp
    otherP =
      asum
        [ lift [|SNewAxis|] <* (string "+" <|> string "NewAxis" <|> string "None"),
          lift [|SEllipsis|] <* (string "..." <|> string "Ellipsis")
        ]

    boolP :: Parser TH.Exp
    boolP = do
      sBool <- asum [[|STrue|] <$ string "True", [|SFalse|] <$ string "False"]
      lift [|SSliceBool $sBool|]

    -- A single index: a signed integer literal, or an antiquoted {variable}.
    indexP :: Parser TH.Exp
    indexP =
      asum
        [ do
            index <- L.signed sc $ lexeme L.decimal
            let con = if index < 0 then [|SNegativeIndex|] else [|SIndex|]
                nat = TH.litT . TH.numTyLit $ abs index
            lift [|$con @($nat)|],
          TH.VarE . TH.mkName <$> lexeme (between (char '{') (char '}') (lexeme (some identChar)))
        ]

    -- Characters allowed in an antiquoted Haskell variable name.
    identChar :: Parser Char
    identChar = M.alphaNumChar <|> M.char '_' <|> M.char '\''

    -- Alternatives are tried in order, each with backtracking, so the longer
    -- slice forms must come before their prefixes.
    sliceP :: Parser TH.Exp
    sliceP =
      asum . map try $
        [ do
            -- from:upTo:step
            from <- indexP
            void $ char ':'
            upTo <- indexP
            void $ char ':'
            step <- indexP
            lift [|SSliceFromUpToWithStep $(pure from) $(pure upTo) $(pure step)|],
          do
            -- :upTo:step
            void $ char ':'
            upTo <- indexP
            void $ char ':'
            step <- indexP
            lift [|SSliceUpToWithStep $(pure upTo) $(pure step)|],
          do
            -- from::step
            from <- indexP
            void $ char ':' <* char ':'
            step <- indexP
            lift [|SSliceFromWithStep $(pure from) $(pure step)|],
          do
            -- from:upTo, optionally followed by a ':' with an omitted step
            -- ("1:3:"), as in Python.
            from <- indexP
            void $ char ':'
            upTo <- indexP
            void . optional $ char ':'
            lift [|SSliceFromUpTo $(pure from) $(pure upTo)|],
          do
            -- ::step
            void $ char ':' <* char ':'
            step <- indexP
            lift [|SSliceWithStep $(pure step)|],
          do
            -- :upTo, optionally followed by ':'
            void $ char ':'
            upTo <- indexP
            void . optional $ char ':'
            lift [|SSliceUpTo $(pure upTo)|],
          do
            -- from: or from::
            from <- indexP
            void $ char ':' <* optional (char ':')
            lift [|SSliceFrom $(pure from)|],
          do
            -- a single index
            at <- indexP
            lift [|SSliceAt $(pure at)|],
          -- : or ::
          lift [|SSliceAll|] <* char ':' <* optional (char ':')
        ]

-- | Generate indices from a [Python-compatible indexing expression](https://pytorch.org/cppdocs/notes/tensor_indexing.html).
-- To take the odd-indexed elements of a tensor, written @tensor[1::2]@ in
-- Python, write @tensor ! [slice|1::2|]@ in Hasktorch.
--
-- The quasi-quoter only works in expression position. Using it as a pattern,
-- type or declaration fails at compile time.
slice :: QuasiQuoter
slice =
  QuasiQuoter
    { quoteExp = parseSlice,
      quotePat = notHandled,
      quoteType = notHandled,
      quoteDec = notHandled
    }
  where
    notHandled :: b -> TH.Q a
    notHandled = const . fail $ "'slice' quasiquoter can only be used as an expression."