{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Torch.GraduallyTyped.Index.Type where

import Data.Kind (Type)
import Data.Maybe (fromJust)
import Data.Proxy (Proxy (..))
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..))
import GHC.TypeLits (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Torch.GraduallyTyped.Prelude (IsChecked (..))

-- | A gradually-typed index.
--
-- Promoted to the kind @Index Nat@ (see 'SIndex'), it describes one of:
--
-- * 'UncheckedIndex' — no index information at the type level;
-- * 'Index' — a non-negative index;
-- * 'NegativeIndex' — a negative index, stored by its magnitude.
data Index (index :: Type) where
  UncheckedIndex :: forall index. Index index
  Index :: forall index. index -> Index index
  NegativeIndex :: forall index. index -> Index index
  deriving stock (Show)

-- | Singleton for the promoted kind @Index Nat@.
--
-- The checked constructors capture a 'KnownNat' dictionary, so the index
-- value can be recovered with 'natVal'. The unchecked constructor carries
-- the value as a runtime 'Integer' instead.
data SIndex (index :: Index Nat) where
  -- | Index whose value is only available at runtime.
  SUncheckedIndex ::
    Integer ->
    SIndex 'UncheckedIndex
  -- | Non-negative index known at the type level.
  SIndex ::
    forall n.
    KnownNat n =>
    SIndex ('Index n)
  -- | Negative index whose magnitude @n@ is known at the type level.
  SNegativeIndex ::
    forall n.
    KnownNat n =>
    SIndex ('NegativeIndex n)

deriving stock instance Show (SIndex i)

-- | Register 'SIndex' as the singleton type for @Index Nat@.
type instance Sing = SIndex

-- | Build the singleton for a non-negative type-level index. 'SIndex'
-- captures the 'KnownNat' dictionary supplied by the instance context.
instance KnownNat index => SingI ('Index index) where
  sing :: Sing ('Index index)
  sing = SIndex

-- | Build the singleton for a negative type-level index. The type argument
-- is the index's magnitude, and 'SNegativeIndex' captures its 'KnownNat'
-- dictionary from the instance context.
instance KnownNat index => SingI ('NegativeIndex index) where
  sing :: Sing ('NegativeIndex index)
  sing = SNegativeIndex

-- | The type-level natural stored in a checked index.
--
-- For @'NegativeIndex n@ this is the magnitude @n@, not a signed value.
-- There is no equation for @'UncheckedIndex@, so @IndexF 'UncheckedIndex@
-- does not reduce.
type family IndexF (index :: Index Nat) :: Nat where
  IndexF ('Index n) = n
  IndexF ('NegativeIndex n) = n

-- | Runtime (demoted) value of an index.
--
-- A wrapper around a signed 'Integer', used as the payload of
-- @Demote (Index Nat)@: negative values correspond to 'NegativeIndex'.
-- Stock 'Eq', 'Ord' and 'Show' instances make demoted indices comparable
-- and printable.
newtype DemotedIndex = DemotedIndex Integer
  deriving stock (Eq, Ord, Show)

-- | Conversion between the singleton 'SIndex' and the runtime form
-- @IsChecked DemotedIndex@.
--
-- * @'UncheckedIndex@ maps to 'Unchecked', carrying the 'Integer' stored in
--   'SUncheckedIndex'.
-- * @'Index n@ demotes to @Checked (DemotedIndex n)@.
-- * @'NegativeIndex n@ demotes to @Checked (DemotedIndex (-n))@.
--
-- 'toSing' chooses the checked constructor from the sign of the value:
-- negative values become 'SNegativeIndex' of their magnitude, and all other
-- values (including 0) become 'SIndex'. As a result, @'NegativeIndex 0@ does
-- not round-trip: it demotes to 0 and is promoted back as @'Index 0@.
instance SingKind (Index Nat) where
  type Demote (Index Nat) = IsChecked DemotedIndex

  fromSing :: forall (a :: Index Nat). Sing a -> Demote (Index Nat)
  fromSing (SUncheckedIndex index) = Unchecked $ DemotedIndex index
  -- The pattern signature names the refined index so 'IndexF' can extract
  -- the natural whose 'KnownNat' dictionary the constructor captured.
  fromSing (SIndex :: SIndex index) =
    Checked . DemotedIndex . natVal $ Proxy @(IndexF index)
  fromSing (SNegativeIndex :: SIndex index) =
    Checked . DemotedIndex . negate . natVal $ Proxy @(IndexF index)

  toSing :: Demote (Index Nat) -> SomeSing (Index Nat)
  toSing (Unchecked (DemotedIndex index)) = SomeSing $ SUncheckedIndex index
  toSing (Checked (DemotedIndex index)) =
    -- 'abs' makes the argument non-negative, so 'someNatVal' always succeeds;
    -- the sign then selects the constructor.
    case someNatVal (abs index) of
      Just (SomeNat (_ :: Proxy n))
        | index < 0 -> SomeSing (SNegativeIndex @n)
        | otherwise -> SomeSing (SIndex @n)
      Nothing ->
        error
          "Torch.GraduallyTyped.Index.Type.toSing: impossible, someNatVal rejected a non-negative Integer"