{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

-- | Gradually typed indices.
--
-- An index is either unchecked (its value is only available at runtime) or
-- checked (its value is a type-level 'Nat', optionally marked as negative).
module Torch.GraduallyTyped.Index.Type where

import Data.Kind (Type)
import Data.Maybe (fromJust)
import Data.Proxy (Proxy (..))
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..))
import GHC.TypeLits (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import qualified GHC.TypeNats as TypeNats
import Torch.GraduallyTyped.Prelude (IsChecked (..))

-- | An index whose value may or may not be known at the type level.
data Index (index :: Type) where
  -- | The index value is not tracked in the type.
  UncheckedIndex :: forall index. Index index
  -- | An index with value @index@.
  Index :: forall index. index -> Index index
  -- | An index with value @-index@ (see the 'SingKind' instance below).
  NegativeIndex :: forall index. index -> Index index
  deriving stock (Show)

-- | Singleton for promoted @'Index' 'Nat'@ values.
--
-- 'SUncheckedIndex' carries the index value at runtime; the checked
-- constructors carry only a 'KnownNat' dictionary for their type-level value.
data SIndex (index :: Index Nat) where
  SUncheckedIndex :: Integer -> SIndex 'UncheckedIndex
  SIndex :: forall index. KnownNat index => SIndex ('Index index)
  SNegativeIndex :: forall index. KnownNat index => SIndex ('NegativeIndex index)

deriving stock instance Show (SIndex (index :: Index Nat))

type instance Sing = SIndex

instance KnownNat index => SingI ('Index index) where
  sing = SIndex

instance KnownNat index => SingI ('NegativeIndex index) where
  sing = SNegativeIndex

-- | The type-level magnitude of a checked index (the sign is dropped).
type family IndexF (index :: Index Nat) :: Nat where
  IndexF ('Index index) = index
  IndexF ('NegativeIndex index) = index

-- | Runtime (signed) value of an index.
newtype DemotedIndex = DemotedIndex Integer

-- | Conversion between type-level indices and their runtime values.
--
-- A checked index demotes to its signed value: @'Index n@ to @n@ and
-- @'NegativeIndex n@ to @-n@. An unchecked index demotes to the 'Integer'
-- stored in 'SUncheckedIndex'.
instance SingKind (Index Nat) where
  type Demote (Index Nat) = IsChecked DemotedIndex

  fromSing (SUncheckedIndex index) = Unchecked (DemotedIndex index)
  fromSing (SIndex :: SIndex index) =
    Checked . DemotedIndex . natVal $ Proxy @(IndexF index)
  fromSing (SNegativeIndex :: SIndex index) =
    Checked . DemotedIndex . negate . natVal $ Proxy @(IndexF index)

  -- A negative checked value is promoted to 'NegativeIndex' of its magnitude.
  -- Zero and positive values become 'Index'. As a consequence @'NegativeIndex 0@
  -- does not round-trip: it demotes to 0, which promotes back to @'Index 0@.
  toSing (Unchecked (DemotedIndex index)) = SomeSing (SUncheckedIndex index)
  toSing (Checked (DemotedIndex index)) =
    -- 'TypeNats.someNatVal' is total. @abs index@ is never negative, so the
    -- 'fromInteger' conversion to 'Natural' cannot underflow. This replaces
    -- the previous 'fromJust' on 'GHC.TypeLits.someNatVal'.
    case TypeNats.someNatVal (fromInteger (abs index)) of
      SomeNat (_ :: Proxy n)
        | index < 0 -> SomeSing (SNegativeIndex @n)
        | otherwise -> SomeSing (SIndex @n)