{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.Index.Class where

import Data.Kind (Constraint)
import Data.Type.Equality (type (==))
import GHC.TypeLits (CmpNat, Nat, Symbol, TypeError)
import Torch.GraduallyTyped.Index.Type (Index (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name, Size (..))
import Type.Errors.Pretty (type (<>))

-- | Custom type error reported when a statically known index does not fit
-- into a dimension. The rendered message shows the offending index first,
-- followed by the dimension it was checked against.
type family IndexOutOfBound (idx :: Nat) (dim :: Dim (Name Symbol) (Size Nat)) where
  IndexOutOfBound i d =
    TypeError
      ( "Out of bound index "
          <> i
          <> " for dimension "
          <> d
      )

-- | Decides, as a type-level 'Bool', whether an index lies within a dimension.
--
-- * An 'UncheckedIndex' is accepted against any dimension.
-- * Any index is accepted against a dimension whose size is 'UncheckedSize'.
-- * A checked index @i@ against a checked size @n@ is accepted exactly when
--   @CmpNat i n@ is @LT@, i.e. when @i < n@.
--
-- Equations are tried top to bottom, so the unchecked cases take precedence.
type family InRangeImplF (idx :: Index Nat) (dim :: Dim (Name Symbol) (Size Nat)) :: Bool where
  InRangeImplF 'UncheckedIndex _dim = 'True
  InRangeImplF _idx ('Dim _name 'UncheckedSize) = 'True
  InRangeImplF ('Index i) ('Dim _name ('Size n)) = CmpNat i n == 'LT

-- | Turns the boolean verdict @ok@ of a range check into a 'Constraint'.
--
-- A verdict of @'True@ yields the trivially satisfied constraint @()@.
-- For any other verdict on a checked index, the constraint becomes the
-- 'IndexOutOfBound' type error for that index and dimension.
type family InRangeCheckF (idx :: Index Nat) (dim :: Dim (Name Symbol) (Size Nat)) (ok :: Bool) :: Constraint where
  InRangeCheckF _idx _dim 'True = ()
  InRangeCheckF ('Index i) d _ok = IndexOutOfBound i d

-- | Constraint that the index @idx@ is in range for the dimension @dim@.
--
-- The verdict is computed by 'InRangeImplF' and then handed to
-- 'InRangeCheckF', which either discharges the constraint or produces a
-- type error.
type InRangeF (idx :: Index Nat) (dim :: Dim (Name Symbol) (Size Nat)) =
  InRangeCheckF idx dim (InRangeImplF idx dim)