{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE RankNTypes #-}
module Torch.Typed.Auxiliary where
import qualified Data.Int as I
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import Data.Constraint
import Unsafe.Coerce (unsafeCoerce)
-- | Reify the type-level natural @n@ as a term-level 'Int'.
--
-- @n@ occurs only in the constraint, so callers pick it with a type
-- application, e.g. @natValI \@3@. The value is obtained from 'natVal'
-- (an 'Integer') and narrowed with 'fromIntegral'.
natValI :: forall n. KnownNat n => Int
natValI = fromIntegral $ natVal $ Proxy @n
-- | Reify the type-level natural @n@ as a term-level 'I.Int16'.
--
-- @n@ is chosen with a type application, e.g. @natValInt16 \@7@. The
-- 'Integer' produced by 'natVal' is narrowed with 'fromIntegral', so a
-- value outside the 'I.Int16' range is not rejected here.
natValInt16 :: forall n. KnownNat n => I.Int16
natValInt16 = fromIntegral $ natVal $ Proxy @n
-- | Type-level projection onto the first component of a promoted pair.
type family Fst (t :: (a, b)) :: a where
  Fst '(x, _) = x
-- | Type-level projection onto the second component of a promoted pair.
type family Snd (t :: (a, b)) :: b where
  Snd '(_, y) = y
-- | Type-level projection onto the first component of a promoted triple.
type family Fst3 (t :: (a, b, c)) :: a where
  Fst3 '(first, _, _) = first
-- | Type-level projection onto the second component of a promoted triple.
type family Snd3 (t :: (a, b, c)) :: b where
  Snd3 '(_, second, _) = second
-- | Type-level projection onto the third component of a promoted triple.
type family Trd3 (t :: (a, b, c)) :: c where
  Trd3 '(_, _, third) = third
-- | Worker for 'DimOutOfBoundCheck'.
--
-- Consumes one element of @xs@ per decrement of @n@. If @xs@ runs out
-- first (including when both run out together) the result is the
-- 'DimOutOfBound' error for the untouched @shape@ and @dim@; if @n@
-- reaches 0 while elements remain, the constraint is trivially satisfied.
type family DimOutOfBoundCheckImpl (shape :: [a]) (dim :: Nat) (xs :: [a]) (n :: Nat) :: Constraint where
  DimOutOfBoundCheckImpl fullShape wanted '[] _ = DimOutOfBound fullShape wanted
  DimOutOfBoundCheckImpl _ _ _ 0 = ()
  DimOutOfBoundCheckImpl fullShape wanted (_ ': rest) k =
    DimOutOfBoundCheckImpl fullShape wanted rest (k - 1)

-- | Constraint that @dim@ is a valid zero-based position in @shape@.
type DimOutOfBoundCheck shape dim = DimOutOfBoundCheckImpl shape dim shape dim
-- | Custom type error reporting that @dim@ is not a valid dimension of
-- @shape@. The message shows @dim@ and the length of @shape@ (computed by
-- @ListLength@, which is not defined in this section).
type family DimOutOfBound (shape :: [a]) (dim :: Nat) where
  DimOutOfBound s d =
    TypeError
      ( Text "Out of bound dimension: "
          :<>: ShowType d
          :<>: Text " (the tensor is only "
          :<>: ShowType (ListLength s)
          :<>: Text "D)"
      )
-- | Custom type error reporting that @idx@ is not a valid index along
-- dimension @dim@ of @shape@. The message shows all three.
type family IndexOutOfBound (shape :: [a]) (dim :: Nat) (idx :: Nat) where
  IndexOutOfBound s d i =
    TypeError
      ( Text "Out of bound index "
          :<>: ShowType i
          :<>: Text " for dimension "
          :<>: ShowType d
          :<>: Text " (the tensor shape is "
          :<>: ShowType s
          :<>: Text ")"
      )
-- | Cons @h@ onto the front of an optional list, propagating 'Nothing'.
--
-- Despite the name, the element is placed at the head of the list.
type family AppendToMaybe (h :: a) (mt :: Maybe [a]) :: Maybe [a] where
  AppendToMaybe _ 'Nothing = 'Nothing
  AppendToMaybe x ('Just xs) = 'Just (x ': xs)
-- | Like 'AppendToMaybe', but the head is optional too: the result is
-- 'Just' only when both the head and the tail are 'Just'.
type family AppendToMaybe' (h :: Maybe a) (mt :: Maybe [a]) :: Maybe [a] where
  AppendToMaybe' 'Nothing _ = 'Nothing
  AppendToMaybe' _ 'Nothing = 'Nothing
  AppendToMaybe' ('Just x) ('Just xs) = 'Just (x ': xs)
-- | Cons an optional head onto a list; a 'Nothing' head leaves the list
-- unchanged.
type family MaybePrepend (mh :: Maybe a) (t :: [a]) :: [a] where
  MaybePrepend 'Nothing xs = xs
  MaybePrepend ('Just x) xs = x ': xs
-- | Zero-based position of the last element of a non-empty list, i.e. its
-- length minus one. There is no equation for the empty list, so
-- @LastDim '[]@ does not reduce.
type family LastDim (l :: [a]) :: Nat where
  LastDim '[_] = 0
  LastDim (_ ': rest) = 1 + LastDim rest
-- | Product of a type-level list of naturals; the empty product is 1.
--
-- The multiplication operator is written qualified so that it is not read
-- as the kind @*@.
type family Product (xs :: [Nat]) :: Nat where
  Product '[] = 1
  Product (n ': ns) = n GHC.TypeLits.* Product ns
-- | Worker for 'Backwards': subtracts @n@ from the index of the last
-- element.
type family BackwardsImpl (last :: Nat) (n :: Nat) :: Nat where
  BackwardsImpl final offset = final - offset

-- | Convert an offset counted from the end of @l@ into a zero-based
-- position counted from the front (offset 0 is the last element).
type Backwards l n = BackwardsImpl (LastDim l) n
-- | Constraint that @xs@ is a suffix of @ys@; otherwise a custom type
-- error is raised.
type IsSuffixOf xs ys =
  CheckIsSuffixOf xs ys (IsSuffixOfImpl xs ys (DropLengthMaybe xs ys))

-- | Turn the boolean verdict into a constraint: satisfied on 'True, a
-- 'TypeError' naming both lists on 'False.
type family CheckIsSuffixOf (xs :: [a]) (ys :: [a]) (result :: Bool) :: Constraint where
  CheckIsSuffixOf _ _ 'True = ()
  CheckIsSuffixOf suffix whole 'False =
    TypeError (ShowType suffix :<>: Text " is not a suffix of " :<>: ShowType whole)

-- | Decide the suffix relation. @mDelta@ is the result of dropping
-- @length xs@ elements from @ys@: 'Nothing means @xs@ is longer than @ys@;
-- otherwise dropping @length delta@ elements from @ys@ leaves the
-- candidate suffix, which is compared with @xs@.
type family IsSuffixOfImpl (xs :: [a]) (ys :: [a]) (mDelta :: Maybe [b]) :: Bool where
  IsSuffixOfImpl suffix whole ('Just delta) = suffix == DropLength delta whole
  IsSuffixOfImpl _ _ 'Nothing = 'False
-- | Drop as many elements from @ys@ as @xs@ has. Yields 'Nothing when
-- @ys@ is exhausted before @xs@.
type family DropLengthMaybe (xs :: [a]) (ys :: [b]) :: Maybe [b] where
  DropLengthMaybe '[] remaining = 'Just remaining
  DropLengthMaybe _ '[] = 'Nothing
  DropLengthMaybe (_ ': as) (_ ': bs) = DropLengthMaybe as bs
-- | Drop as many elements from @ys@ as @xs@ has. Yields the empty list
-- when @ys@ is exhausted before @xs@.
type family DropLength (xs :: [a]) (ys :: [b]) :: [b] where
  DropLength '[] remaining = remaining
  DropLength _ '[] = '[]
  DropLength (_ ': as) (_ ': bs) = DropLength as bs
-- | All elements of a type-level list except the last one. Applying it to
-- the empty list is a custom type error.
type family Init (xs :: [a]) :: [a] where
  Init '[] = TypeError (Text "Init of empty list.")
  Init '[_] = '[]
  Init (y ': ys) = y ': Init ys
-- | The final element of a type-level list. Applying it to the empty list
-- is a custom type error.
type family Last (xs :: [a]) :: a where
  Last '[] = TypeError (Text "Last of empty list.")
  Last '[y] = y
  Last (_ ': ys) = Last ys
-- | Worker for 'Insert': place @x@ so that it ends up at zero-based
-- position @n@ of the result. Yields 'Nothing when @n@ is greater than
-- the length of @l@.
type family InsertImpl (n :: Nat) (x :: a) (l :: [a]) :: Maybe [a] where
  InsertImpl 0 new xs = 'Just (new ': xs)
  InsertImpl _ _ '[] = 'Nothing
  InsertImpl k new (y ': ys) = AppendToMaybe y (InsertImpl (k - 1) new ys)

-- | Unwrap the result of 'InsertImpl', turning 'Nothing into the
-- 'DimOutOfBound' type error for @l@ and @n@.
type family CheckInsert (n :: Nat) (x :: a) (l :: [a]) (result :: Maybe [a]) :: [a] where
  CheckInsert _ _ _ ('Just inserted) = inserted
  CheckInsert k _ xs 'Nothing = DimOutOfBound xs k

-- | Insert @x@ into @l@ at zero-based position @n@, or fail with a type
-- error when the position is out of bounds.
type family Insert (n :: Nat) (x :: a) (l :: [a]) :: [a] where
  Insert k new xs = CheckInsert k new xs (InsertImpl k new xs)
-- | Worker for 'Remove': delete the element at zero-based position @n@.
-- Yields 'Nothing when the list has no element at that position.
type family RemoveImpl (l :: [a]) (n :: Nat) :: Maybe [a] where
  RemoveImpl (_ ': ys) 0 = 'Just ys
  RemoveImpl (y ': ys) k = AppendToMaybe y (RemoveImpl ys (k - 1))
  RemoveImpl _ _ = 'Nothing

-- | Unwrap the result of 'RemoveImpl', turning 'Nothing into the
-- 'DimOutOfBound' type error for @l@ and @n@.
type family CheckRemove (l :: [a]) (n :: Nat) (result :: Maybe [a]) :: [a] where
  CheckRemove xs k 'Nothing = DimOutOfBound xs k
  CheckRemove _ _ ('Just remaining) = remaining

-- | Remove the element at zero-based position @n@ of @l@, or fail with a
-- type error when the position is out of bounds.
type Remove l n = CheckRemove l n (RemoveImpl l n)
-- | Worker for 'Index': look up the element at zero-based position @n@.
-- Yields 'Nothing when the list has no element at that position.
type family IndexImpl (l :: [a]) (n :: Nat) :: Maybe a where
  IndexImpl (y ': _) 0 = 'Just y
  IndexImpl (_ ': ys) k = IndexImpl ys (k - 1)
  IndexImpl _ _ = 'Nothing

-- | Unwrap the result of 'IndexImpl', turning 'Nothing into the
-- 'DimOutOfBound' type error for @l@ and @n@.
type family CheckIndex (l :: [a]) (n :: Nat) (result :: Maybe a) :: a where
  CheckIndex xs k 'Nothing = DimOutOfBound xs k
  CheckIndex _ _ ('Just found) = found

-- | The element at zero-based position @n@ of @l@, or a type error when
-- the position is out of bounds.
type Index l n = CheckIndex l n (IndexImpl l n)
-- | Turn the comparison of @idx@ with the size of dimension @dim@ into a
-- constraint: satisfied only for 'LT; any other ordering produces the
-- 'IndexOutOfBound' type error.
type family InRangeCheck (shape :: [Nat]) (dim :: Nat) (idx :: Nat) (ok :: Ordering) :: Constraint where
  InRangeCheck _ _ _ 'LT = ()
  InRangeCheck sizes d i _ = IndexOutOfBound sizes d i

-- | Constraint that @idx@ is strictly smaller than the size of dimension
-- @dim@ of @shape@.
type InRange shape dim idx =
  InRangeCheck shape dim idx (CmpNat idx (Index shape dim))
-- | Worker for 'Reverse': moves elements of @l@ one by one onto the front
-- of the accumulator.
type family ReverseImpl (l :: [a]) (acc :: [a]) :: [a] where
  ReverseImpl '[] done = done
  ReverseImpl (y ': ys) done = ReverseImpl ys (y ': done)

-- | Reverse a type-level list.
type Reverse l = ReverseImpl l '[]
-- | Size of the dimension at zero-based position @dim@ of @shape@.
-- Yields 'Nothing when @shape@ has no such dimension.
type family ExtractDim (dim :: Nat) (shape :: [Nat]) :: Maybe Nat where
  ExtractDim 0 (h ': _) = 'Just h
  ExtractDim dim (_ ': t) = ExtractDim (dim - 1) t
  ExtractDim _ _ = 'Nothing
-- | Replace the size at zero-based position @dim@ of @shape@ with @n@.
-- Yields 'Nothing when @shape@ has no such dimension.
type family ReplaceDim (dim :: Nat) (shape :: [Nat]) (n :: Nat) :: Maybe [Nat] where
  ReplaceDim 0 (_ ': rest) new = 'Just (new ': rest)
  ReplaceDim k (size ': rest) new = AppendToMaybe size (ReplaceDim (k - 1) rest new)
  ReplaceDim _ _ _ = 'Nothing
-- | Type-level conditional: selects the second argument when the
-- condition is 'True and the third when it is 'False.
type family If c t e where
  If 'True onTrue _ = onTrue
  If 'False _ onFalse = onFalse
-- | Constraint that every entry of @shape@ is at least 1. The first entry
-- that fails the @1 <=? x@ test produces a custom type error naming it.
type family AllDimsPositive (shape :: [Nat]) :: Constraint where
  AllDimsPositive '[] = ()
  AllDimsPositive (size ': rest) =
    If
      (1 <=? size)
      (AllDimsPositive rest)
      (TypeError (Text "Expected positive dimension but got " :<>: ShowType size :<>: Text "!"))
-- | Constraint behind '>=': given @cmp = CmpNat n m@, it is satisfied
-- unless @cmp@ is 'LT (that is, unless @n < m@), in which case a custom
-- type error is raised.
--
-- In the failing case @m@ is the required minimum and @n@ is the offending
-- value, so the message reports "at least @m@ but got @n@".
type family IsAtLeast (n :: Nat) (m :: Nat) (cmp :: Ordering) :: Constraint where
  IsAtLeast n m 'LT =
    TypeError
      ( Text "Expected a dimension of size at least "
          :<>: ShowType m
          :<>: Text " but got "
          :<>: ShowType n
          :<>: Text "!"
      )
  IsAtLeast _ _ _ = ()

-- | Constraint that @n@ is at least @m@. It also provides
-- @KnownNat (n - m)@ to the code it guards.
type (>=) (n :: Nat) (m :: Nat) = (IsAtLeast n m (CmpNat n m), KnownNat (n - m))
-- | Compare two dtypes along the chain
-- Bool < UInt8 < Int8 < Int16 < Int32 < Int64 < Half < Float < Double.
--
-- Identical dtypes give 'EQ, every listed "smaller, larger" pair gives
-- 'LT, and anything else falls through to 'GT.
type family CmpDType (dtype :: D.DType) (dtype' :: D.DType) :: Ordering where
  CmpDType d d = 'EQ
  -- Bool is below everything else.
  CmpDType 'D.Bool 'D.UInt8 = 'LT
  CmpDType 'D.Bool 'D.Int8 = 'LT
  CmpDType 'D.Bool 'D.Int16 = 'LT
  CmpDType 'D.Bool 'D.Int32 = 'LT
  CmpDType 'D.Bool 'D.Int64 = 'LT
  CmpDType 'D.Bool 'D.Half = 'LT
  CmpDType 'D.Bool 'D.Float = 'LT
  CmpDType 'D.Bool 'D.Double = 'LT
  -- UInt8
  CmpDType 'D.UInt8 'D.Int8 = 'LT
  CmpDType 'D.UInt8 'D.Int16 = 'LT
  CmpDType 'D.UInt8 'D.Int32 = 'LT
  CmpDType 'D.UInt8 'D.Int64 = 'LT
  CmpDType 'D.UInt8 'D.Half = 'LT
  CmpDType 'D.UInt8 'D.Float = 'LT
  CmpDType 'D.UInt8 'D.Double = 'LT
  -- Int8
  CmpDType 'D.Int8 'D.Int16 = 'LT
  CmpDType 'D.Int8 'D.Int32 = 'LT
  CmpDType 'D.Int8 'D.Int64 = 'LT
  CmpDType 'D.Int8 'D.Half = 'LT
  CmpDType 'D.Int8 'D.Float = 'LT
  CmpDType 'D.Int8 'D.Double = 'LT
  -- Int16
  CmpDType 'D.Int16 'D.Int32 = 'LT
  CmpDType 'D.Int16 'D.Int64 = 'LT
  CmpDType 'D.Int16 'D.Half = 'LT
  CmpDType 'D.Int16 'D.Float = 'LT
  CmpDType 'D.Int16 'D.Double = 'LT
  -- Int32
  CmpDType 'D.Int32 'D.Int64 = 'LT
  CmpDType 'D.Int32 'D.Half = 'LT
  CmpDType 'D.Int32 'D.Float = 'LT
  CmpDType 'D.Int32 'D.Double = 'LT
  -- Int64
  CmpDType 'D.Int64 'D.Half = 'LT
  CmpDType 'D.Int64 'D.Float = 'LT
  CmpDType 'D.Int64 'D.Double = 'LT
  -- Half
  CmpDType 'D.Half 'D.Float = 'LT
  CmpDType 'D.Half 'D.Double = 'LT
  -- Float
  CmpDType 'D.Float 'D.Double = 'LT
  -- Every remaining combination.
  CmpDType _ _ = 'GT
-- | Worker for 'DTypePromotion'. Mixing UInt8 and Int8 (in either order)
-- promotes to Int16 regardless of @ord@; otherwise @ord@ picks a side:
-- the first dtype on 'EQ or 'GT, the second on 'LT.
type family DTypePromotionImpl (dtype :: D.DType) (dtype' :: D.DType) (ord :: Ordering) :: D.DType where
  DTypePromotionImpl 'D.UInt8 'D.Int8 _ = 'D.Int16
  DTypePromotionImpl 'D.Int8 'D.UInt8 _ = 'D.Int16
  DTypePromotionImpl lhs _ 'EQ = lhs
  DTypePromotionImpl _ rhs 'LT = rhs
  DTypePromotionImpl lhs _ 'GT = lhs

-- | The dtype that results from combining @dtype@ and @dtype'@, decided
-- by their 'CmpDType' ordering.
type DTypePromotion dtype dtype' =
  DTypePromotionImpl dtype dtype' (CmpDType dtype dtype')
-- | Constraint that @dtype@ is Half, Float or Double. Any other dtype
-- produces the 'UnsupportedDTypeForDevice' error for the device's type.
type family DTypeIsFloatingPoint (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  DTypeIsFloatingPoint _ 'D.Half = ()
  DTypeIsFloatingPoint _ 'D.Float = ()
  DTypeIsFloatingPoint _ 'D.Double = ()
  DTypeIsFloatingPoint '(devType, _) other = UnsupportedDTypeForDevice devType other
-- | Constraint that @dtype@ is Bool, UInt8, Int8, Int16, Int32 or Int64.
-- Any other dtype produces the 'UnsupportedDTypeForDevice' error for the
-- device's type.
type family DTypeIsIntegral (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  DTypeIsIntegral _ 'D.Bool = ()
  DTypeIsIntegral _ 'D.UInt8 = ()
  DTypeIsIntegral _ 'D.Int8 = ()
  DTypeIsIntegral _ 'D.Int16 = ()
  DTypeIsIntegral _ 'D.Int32 = ()
  DTypeIsIntegral _ 'D.Int64 = ()
  DTypeIsIntegral '(devType, _) other = UnsupportedDTypeForDevice devType other
-- | Constraint that @dtype@ is anything but Half; Half produces the
-- 'UnsupportedDTypeForDevice' error for the device's type.
type family DTypeIsNotHalf (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  DTypeIsNotHalf '(devType, _) 'D.Half = UnsupportedDTypeForDevice devType 'D.Half
  DTypeIsNotHalf _ _ = ()
-- | Constraint that @dtype@ is anything but Bool; Bool produces the
-- 'UnsupportedDTypeForDevice' error for the device's type.
type family DTypeIsNotBool (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  DTypeIsNotBool '(devType, _) 'D.Bool = UnsupportedDTypeForDevice devType 'D.Bool
  DTypeIsNotBool _ _ = ()
-- | Custom type error stating that the operation does not support tensors
-- of @dtype@ on devices of type @deviceType@.
type family UnsupportedDTypeForDevice (deviceType :: D.DeviceType) (dtype :: D.DType) :: Constraint where
  UnsupportedDTypeForDevice devType dt =
    TypeError
      ( Text "This operation does not support "
          :<>: ShowType dt
          :<>: Text " tensors on devices of type "
          :<>: ShowType devType
          :<>: Text "."
      )
-- | Dtype check for floating-point operations, by device:
--
-- * CPU with index 0: the dtype must be floating point and not Half;
-- * CUDA with any index: the dtype must be floating point;
-- * any other device: always the 'UnsupportedDTypeForDevice' error.
type family StandardFloatingPointDTypeValidation (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  StandardFloatingPointDTypeValidation '( 'D.CPU, 0) dt =
    ( DTypeIsFloatingPoint '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  StandardFloatingPointDTypeValidation '( 'D.CUDA, idx) dt =
    DTypeIsFloatingPoint '( 'D.CUDA, idx) dt
  StandardFloatingPointDTypeValidation '(devType, _) dt =
    UnsupportedDTypeForDevice devType dt
-- | General dtype check, by device:
--
-- * CPU with index 0: the dtype must be neither Bool nor Half;
-- * CUDA with any index: the dtype must not be Bool;
-- * any other device: always the 'UnsupportedDTypeForDevice' error.
type family StandardDTypeValidation (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  StandardDTypeValidation '( 'D.CPU, 0) dt =
    ( DTypeIsNotBool '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  StandardDTypeValidation '( 'D.CUDA, idx) dt =
    DTypeIsNotBool '( 'D.CUDA, idx) dt
  StandardDTypeValidation '(devType, _) dt =
    UnsupportedDTypeForDevice devType dt
-- | Discharge the constraint @c@ of a computation without any evidence
-- for it.
--
-- The dictionary handed to 'withDict' is fabricated: a @Dict ()@ is
-- reinterpreted as @Dict c@ via 'unsafeCoerce'. Nothing checks that @c@
-- actually holds.
--
-- NOTE(review): presumably only sound when the guarded computation never
-- uses runtime evidence carried by @c@ -- confirm at each call site.
unsafeConstraint :: forall c a. (c => a) -> a
unsafeConstraint = withDict (dummyDict @c)
  where
    -- A dictionary for an arbitrary constraint, forged from the trivial one.
    dummyDict :: forall b. Dict b
    dummyDict = unsafeCoerce (Dict :: Dict ())
-- | Promote a run-time 'Int' to a type-level natural and pass a 'Proxy'
-- for it to the continuation.
--
-- Calls 'error' with @"Negative Number in withNat!"@ when 'someNatVal'
-- rejects the number.
withNat ::
  Int ->
  ( forall n.
    KnownNat n =>
    Proxy n ->
    r
  ) ->
  r
withNat i f = case someNatVal (fromIntegral i) of
  Nothing -> error "Negative Number in withNat!"
  Just (SomeNat p) -> f p
-- | Apply @func@ to a 'Proxy' of every natural from 0 up to @n - 1@, in
-- ascending order, and collect the results.
--
-- @n@ is chosen with a type application, e.g. @forEachNat \@3 f@. For
-- @n = 0@ the range is empty and so is the result.
forEachNat :: forall n a. KnownNat n => (forall i. KnownNat i => Proxy i -> a) -> [a]
forEachNat func = map (\i -> withNat i func) [0 .. (natValI @n - 1)]