{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.Shape.Class where
import Control.Exception (Exception (..))
import Control.Monad.Catch (MonadThrow (throwM))
import Data.Singletons (Sing, SingKind (..))
import Data.Typeable (Typeable)
import GHC.TypeLits (Symbol, TypeError, type (+), type (-))
import GHC.TypeNats (Nat)
import Torch.GraduallyTyped.Prelude (Fst, LiftTimesMaybe, MapMaybe, PrependMaybe, Reverse, Snd, forgetIsChecked)
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim (..), SName (..), SSelectDim (..), SShape (..), SSize (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import Type.Errors.Pretty (type (%), type (<>))
import Unsafe.Coerce (unsafeCoerce)
-- | Add two type-level sizes.
--
-- Two checked sizes are summed; in every other case the sizes are combined
-- with '<+>'.
type family AddSizeF (size :: Size Nat) (size' :: Size Nat) :: Size Nat where
  AddSizeF ('Size m) ('Size n) = 'Size (m + n)
  AddSizeF lhs rhs = lhs <+> rhs
-- | Add two type-level dimensions: the names are combined with '<+>' and the
-- sizes are added with 'AddSizeF'.
type family
  AddDimF
    (dim :: Dim (Name Symbol) (Size Nat))
    (dim' :: Dim (Name Symbol) (Size Nat)) ::
    Dim (Name Symbol) (Size Nat)
  where
  AddDimF ('Dim lhsName lhsSize) ('Dim rhsName rhsSize) =
    'Dim (lhsName <+> rhsName) (AddSizeF lhsSize rhsSize)
-- | Broadcast two type-level sizes.
--
-- An unchecked size on either side yields an unchecked size. Two checked
-- sizes broadcast when they are equal or when one of them is 1; otherwise the
-- result is @'Nothing@.
type family
  BroadcastSizeF
    (size :: Size Nat)
    (size' :: Size Nat) ::
    Maybe (Size Nat)
  where
  BroadcastSizeF 'UncheckedSize _ = 'Just 'UncheckedSize
  BroadcastSizeF _ 'UncheckedSize = 'Just 'UncheckedSize
  BroadcastSizeF ('Size n) ('Size n) = 'Just ('Size n)
  -- A size of 1 is stretched to the size on the other side.
  BroadcastSizeF ('Size n) ('Size 1) = 'Just ('Size n)
  BroadcastSizeF ('Size 1) ('Size n) = 'Just ('Size n)
  BroadcastSizeF ('Size _) ('Size _) = 'Nothing
-- | Broadcast two type-level dimensions.
--
-- The names are combined with '<+>'. The sizes are broadcast with
-- 'BroadcastSizeF', and its optional result is mapped with 'MapMaybe' into a
-- dimension.
type family
  BroadcastDimF
    (dim :: Dim (Name Symbol) (Size Nat))
    (dim' :: Dim (Name Symbol) (Size Nat)) ::
    Maybe (Dim (Name Symbol) (Size Nat))
  where
  BroadcastDimF ('Dim lhsName lhsSize) ('Dim rhsName rhsSize) =
    MapMaybe ('Dim (lhsName <+> rhsName)) (BroadcastSizeF lhsSize rhsSize)
-- | Number of elements along one dimension: its size when the size is
-- checked, @'Nothing@ when it is unchecked.
type family NumelDimF (dim :: Dim (Name Symbol) (Size Nat)) :: Maybe Nat where
  NumelDimF ('Dim _ 'UncheckedSize) = 'Nothing
  NumelDimF ('Dim _ ('Size n)) = 'Just n
-- | Turn the result of 'BroadcastDimsImplF' into a dimension list.
--
-- 'BroadcastDimsF' runs the implementation on reversed lists, so a successful
-- result is reversed back here. A failed broadcast (@'Nothing@) becomes a
-- custom type error that mentions both input dimension lists.
type family
  BroadcastDimsCheckF
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (dims' :: [Dim (Name Symbol) (Size Nat)])
    (result :: Maybe [Dim (Name Symbol) (Size Nat)]) ::
    [Dim (Name Symbol) (Size Nat)]
  where
  BroadcastDimsCheckF lhs rhs 'Nothing =
    TypeError
      ( "Cannot broadcast the dimensions"
          % ""
          % " '" <> lhs <> "' and '" <> rhs <> "'."
          % ""
          % "You may need to extend, squeeze, or unsqueeze the dimensions manually."
      )
  BroadcastDimsCheckF _ _ ('Just reversedResult) = Reverse reversedResult
-- | Broadcast two dimension lists that are given in reversed order.
--
-- Heads are paired with 'BroadcastDimF' and the pairwise results are collected
-- with 'PrependMaybe'. Once one list is exhausted, the rest of the other list
-- is kept unchanged. The result is still in reversed order.
type family
  BroadcastDimsImplF
    (reversedDims :: [Dim (Name Symbol) (Size Nat)])
    (reversedDims' :: [Dim (Name Symbol) (Size Nat)]) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  BroadcastDimsImplF '[] rest = 'Just rest
  BroadcastDimsImplF rest '[] = 'Just rest
  BroadcastDimsImplF (d ': ds) (d' ': ds') =
    PrependMaybe (BroadcastDimF d d') (BroadcastDimsImplF ds ds')
-- | Broadcast two dimension lists, aligning them from their last dimension.
--
-- Both lists are reversed, broadcast with 'BroadcastDimsImplF', and then
-- checked (and reversed back) by 'BroadcastDimsCheckF'.
type BroadcastDimsF lhs rhs =
  BroadcastDimsCheckF lhs rhs (BroadcastDimsImplF (Reverse lhs) (Reverse rhs))
-- | Broadcast two type-level shapes.
--
-- Identical shapes are returned unchanged. Two checked shapes are broadcast
-- dimension-wise with 'BroadcastDimsF'. Any other combination falls back to
-- '<+>'.
type family
  BroadcastShapesF
    (shape :: Shape [Dim (Name Symbol) (Size Nat)])
    (shape' :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  BroadcastShapesF s s = s
  BroadcastShapesF ('Shape ds) ('Shape ds') = 'Shape (BroadcastDimsF ds ds')
  BroadcastShapesF s s' = s <+> s'
-- | Number of elements described by a dimension list.
--
-- Per-dimension counts from 'NumelDimF' are combined with 'LiftTimesMaybe',
-- starting from @'Just 1@ for the empty list.
type family NumelDimsF (dims :: [Dim (Name Symbol) (Size Nat)]) :: Maybe Nat where
  NumelDimsF '[] = 'Just 1
  NumelDimsF (d ': ds) = LiftTimesMaybe (NumelDimF d) (NumelDimsF ds)
-- | Number of elements of a type-level shape.
--
-- Yields @'Nothing@ for an unchecked shape and delegates to 'NumelDimsF' for a
-- checked one.
type family NumelF (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Maybe Nat where
  NumelF 'UncheckedShape = 'Nothing
  NumelF ('Shape ds) = NumelDimsF ds
-- | Scan a dimension list for the first dimension whose checked name equals
-- @name@, and return it together with its zero-based position.
--
-- * @index@ is the position of the head of @dims@.
-- * @result@ is returned unchanged once the list is exhausted.
-- * A dimension with an unchecked name replaces @result@ with an unchecked
--   dimension and no index, and the scan continues. Presumably this is because
--   that dimension might be the one looked for.
type family
  GetDimAndIndexByNameF
    (index :: Nat)
    (result :: (Maybe (Dim (Name Symbol) (Size Nat)), Maybe Nat))
    (name :: Symbol)
    (dims :: [Dim (Name Symbol) (Size Nat)]) ::
    (Maybe (Dim (Name Symbol) (Size Nat)), Maybe Nat)
  where
  GetDimAndIndexByNameF _ acc _ '[] = acc
  GetDimAndIndexByNameF i _ target ('Dim 'UncheckedName _ ': rest) =
    GetDimAndIndexByNameF (i + 1) '( 'Just ('Dim 'UncheckedName 'UncheckedSize), 'Nothing) target rest
  GetDimAndIndexByNameF i _ target ('Dim ('Name target) s ': _) =
    '( 'Just ('Dim ('Name target) s), 'Just i)
  GetDimAndIndexByNameF i acc target ('Dim ('Name _) _ ': rest) =
    GetDimAndIndexByNameF (i + 1) acc target rest
-- | The dimension part of 'GetDimAndIndexByNameF', with the scan starting at
-- index 0 and no prior result.
type family
  GetDimByNameF
    (name :: Symbol)
    (dims :: [Dim (Name Symbol) (Size Nat)]) ::
    Maybe (Dim (Name Symbol) (Size Nat))
  where
  GetDimByNameF target ds = Fst (GetDimAndIndexByNameF 0 '( 'Nothing, 'Nothing) target ds)
-- | The index part of 'GetDimAndIndexByNameF', with the scan starting at
-- index 0 and no prior result.
type family
  GetIndexByNameF
    (name :: Symbol)
    (dims :: [Dim (Name Symbol) (Size Nat)]) ::
    Maybe Nat
  where
  GetIndexByNameF target ds = Snd (GetDimAndIndexByNameF 0 '( 'Nothing, 'Nothing) target ds)
-- | The dimension at a zero-based index, or @'Nothing@ when the list ends
-- before the index is reached.
type family
  GetDimByIndexF
    (index :: Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)]) ::
    Maybe (Dim (Name Symbol) (Size Nat))
  where
  GetDimByIndexF 0 (d ': _) = 'Just d
  GetDimByIndexF i (_ ': ds) = GetDimByIndexF (i - 1) ds
  GetDimByIndexF _ _ = 'Nothing
-- | Look up a dimension by name ('GetDimByNameF') or by index
-- ('GetDimByIndexF'), depending on the selector.
type family
  GetDimImplF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)]) ::
    Maybe (Dim (Name Symbol) (Size Nat))
  where
  GetDimImplF ('ByName target) ds = GetDimByNameF target ds
  GetDimImplF ('ByIndex i) ds = GetDimByIndexF i ds
-- | Type error text used by 'GetDimCheckF' when no dimension matches the
-- selector.
type GetDimErrorMessage (selector :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot return the first dimension matching"
    % ""
    % " '" <> selector <> "'"
    % ""
    % "in the shape"
    % ""
    % " '" <> dims <> "'."
    % ""
-- | Unwrap the result of 'GetDimImplF'.
--
-- A missing dimension (@'Nothing@) is reported with 'GetDimErrorMessage'.
type family
  GetDimCheckF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (result :: Maybe (Dim (Name Symbol) (Size Nat))) ::
    Dim (Name Symbol) (Size Nat)
  where
  GetDimCheckF selector ds 'Nothing = TypeError (GetDimErrorMessage selector ds)
  GetDimCheckF _ _ ('Just found) = found
-- | Select the first dimension of a shape that matches a dimension selector.
--
-- If the selector or the shape is unchecked, the result is an unchecked
-- dimension. Otherwise the lookup is done by 'GetDimImplF' and checked by
-- 'GetDimCheckF'.
type family
  GetDimF
    (selectDim :: SelectDim (By Symbol Nat))
    (shape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Dim (Name Symbol) (Size Nat)
  where
  GetDimF 'UncheckedSelectDim _ = 'Dim 'UncheckedName 'UncheckedSize
  GetDimF _ 'UncheckedShape = 'Dim 'UncheckedName 'UncheckedSize
  GetDimF ('SelectDim selector) ('Shape ds) = GetDimCheckF selector ds (GetDimImplF selector ds)
-- | Infix shorthand for 'GetDimF'.
--
-- A 'Nat' key selects by index and a 'Symbol' key selects by name.
type family (!) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) (_k :: k) :: Dim (Name Symbol) (Size Nat) where
  (!) s (i :: Nat) = GetDimF ('SelectDim ('ByIndex i)) s
  (!) s (n :: Symbol) = GetDimF ('SelectDim ('ByName n)) s
-- | Return the first dimension of a singleton shape that matches a singleton
-- dimension selector, selecting either by name or by index.
--
-- * If the selector or the shape is unchecked, both are demoted and the
--   search runs on plain values. The result is then an unchecked dimension.
-- * If both are checked, the matching singleton dimension is taken from the
--   shape and coerced to the type computed by 'GetDimF'.
--
-- Throws a 'GetDimError' via 'throwM' when no dimension matches.
--
-- NOTE(review): 'GetDimF' keeps scanning past dimensions with an unchecked
-- name, looking for an exact checked match. The checked by-name case below
-- instead returns the first unchecked-name dimension whose runtime name
-- matches. If a shape contains both, the returned value may not agree with
-- its type. Confirm whether this is intended.
sGetDimFromShape ::
  forall selectDim shape dim m.
  (dim ~ GetDimF selectDim shape, MonadThrow m) =>
  SSelectDim selectDim ->
  SShape shape ->
  m (SDim dim)
sGetDimFromShape (SUncheckedSelectDim by) (SUncheckedShape dims) = go 0 dims
  where
    -- Walk the demoted dimensions, tracking the current zero-based index.
    go ::
      Integer ->
      [Dim String Integer] ->
      m (SDim ('Dim 'UncheckedName 'UncheckedSize))
    go _ [] = throwM $ GetDimErrorWithDims by dims
    go index (Dim name size : dims') =
      let dim' = SDim (SUncheckedName name) (SUncheckedSize size)
       in case by of
            ByName name' | name == name' -> pure dim'
            ByIndex index' | index == index' -> pure dim'
            _ -> go (index + 1) dims'
sGetDimFromShape (SSelectDim by) (SUncheckedShape dims) =
  -- The shape is unchecked: demote the selector and reuse the unchecked path.
  sGetDimFromShape (SUncheckedSelectDim $ fromSing by) (SUncheckedShape dims)
sGetDimFromShape (SUncheckedSelectDim by) (SShape dims) =
  -- The selector is unchecked: demote the shape and reuse the unchecked path.
  let dims' = (\(Dim name size) -> Dim (forgetIsChecked name) (forgetIsChecked size)) <$> fromSing dims
   in sGetDimFromShape (SUncheckedSelectDim by) (SUncheckedShape dims')
sGetDimFromShape (SSelectDim by@SByName) (SShape SNil) =
  -- Reached the end of the shape without finding the name.
  throwM $ GetDimError (fromSing by)
sGetDimFromShape (SSelectDim by@SByName) (SShape (SCons dim@(SDim (SUncheckedName name) _) dims)) =
  -- Matching on 'SByName' guarantees that the demoted selector is a 'ByName'.
  let ByName name' = fromSing by
   in if name == name'
        then pure (unsafeCoerce @(SDim _) @(SDim dim) dim)
        else unsafeCoerce @(SDim _) @(SDim dim) <$> sGetDimFromShape (SSelectDim by) (SShape dims)
sGetDimFromShape (SSelectDim by@SByName) (SShape (SCons dim@(SDim SName _) dims)) =
  -- Matching on 'SByName' guarantees that the demoted selector is a 'ByName'.
  let ByName name' = fromSing by
      Dim checkedName _ = fromSing dim
      name = forgetIsChecked checkedName
   in if name == name'
        then pure (unsafeCoerce @(SDim _) @(SDim dim) dim)
        else unsafeCoerce @(SDim _) @(SDim dim) <$> sGetDimFromShape (SSelectDim by) (SShape dims)
sGetDimFromShape (SSelectDim by@SByIndex) (SShape dims) = go 0 dims
  where
    -- Matching on 'SByIndex' guarantees that the demoted selector is a 'ByIndex'.
    by'@(ByIndex index') = fromSing by
    -- The demoted shape is only needed for the error message.
    dims' = (\(Dim name size) -> Dim (forgetIsChecked name) (forgetIsChecked size)) <$> fromSing dims
    go :: forall dims. Integer -> SList dims -> m (SDim dim)
    go _ SNil = throwM $ GetDimErrorWithDims by' dims'
    go index (SCons dim dims'') =
      if index' == index
        then pure (unsafeCoerce @(Sing _) @(SDim dim) dim)
        else go (index + 1) dims''
-- | Error thrown by 'sGetDimFromShape' when no dimension matches the demoted
-- selector.
data GetDimError
  = -- | No match; only the selector is recorded.
    GetDimError {gdeBy :: By String Integer}
  | -- | No match; both the selector and the searched (demoted) shape are
    -- recorded.
    GetDimErrorWithDims {gdewdBy :: By String Integer, gdewdDims :: [Dim String Integer]}
  deriving stock (Show, Typeable)
-- | Human-readable rendering of a failed dimension lookup. When it is
-- available, the searched shape is included.
instance Exception GetDimError where
  displayException GetDimError {..} =
    "Cannot return the first dimension matching `"
      <> show gdeBy
      <> "`."
  displayException GetDimErrorWithDims {..} =
    "Cannot return the first dimension matching `"
      <> show gdewdBy
      <> "` in the shape `"
      <> show gdewdDims
      <> "`."
-- | Replace the dimension at a zero-based index with @dim@.
--
-- The catch-all equation yields @'Nothing@ for a @'Nothing@ index, or when the
-- list ends before the index is reached. Preceding dimensions are re-attached
-- with 'PrependMaybe'.
type family
  ReplaceDimByIndexF
    (index :: Maybe Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (dim :: Dim (Name Symbol) (Size Nat)) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  ReplaceDimByIndexF ('Just 0) (_ ': rest) new = 'Just (new ': rest)
  ReplaceDimByIndexF ('Just i) (d ': rest) new =
    PrependMaybe ('Just d) (ReplaceDimByIndexF ('Just (i - 1)) rest new)
  ReplaceDimByIndexF _ _ _ = 'Nothing
-- | Replace the first dimension matching a selector with @dim@.
--
-- A name is first resolved to an index with 'GetIndexByNameF'.
type family
  ReplaceDimImplF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (dim :: Dim (Name Symbol) (Size Nat)) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  ReplaceDimImplF ('ByName target) ds new = ReplaceDimByIndexF (GetIndexByNameF target ds) ds new
  ReplaceDimImplF ('ByIndex i) ds new = ReplaceDimByIndexF ('Just i) ds new
-- | Replace only the name of the dimension at a zero-based index. The size is
-- kept.
--
-- The catch-all equation yields @'Nothing@ for a @'Nothing@ index, or when the
-- list ends first.
type family
  ReplaceDimNameByIndexF
    (index :: Maybe Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (name :: Name Symbol) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  ReplaceDimNameByIndexF ('Just 0) ('Dim _ s ': rest) newName = 'Just ('Dim newName s ': rest)
  ReplaceDimNameByIndexF ('Just i) (d ': rest) newName =
    PrependMaybe ('Just d) (ReplaceDimNameByIndexF ('Just (i - 1)) rest newName)
  ReplaceDimNameByIndexF _ _ _ = 'Nothing
-- | Replace the name of the first dimension matching a selector.
--
-- A name is first resolved to an index with 'GetIndexByNameF'.
type family
  ReplaceDimNameImplF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (name' :: Name Symbol) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  ReplaceDimNameImplF ('ByName target) ds newName =
    ReplaceDimNameByIndexF (GetIndexByNameF target ds) ds newName
  ReplaceDimNameImplF ('ByIndex i) ds newName = ReplaceDimNameByIndexF ('Just i) ds newName
-- | Replace only the size of the dimension at a zero-based index. The name is
-- kept.
--
-- The catch-all equation yields @'Nothing@ for a @'Nothing@ index, or when the
-- list ends first.
type family
  ReplaceDimSizeByIndexF
    (index :: Maybe Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (size' :: Size Nat) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  ReplaceDimSizeByIndexF ('Just 0) ('Dim n _ ': rest) newSize = 'Just ('Dim n newSize ': rest)
  ReplaceDimSizeByIndexF ('Just i) (d ': rest) newSize =
    PrependMaybe ('Just d) (ReplaceDimSizeByIndexF ('Just (i - 1)) rest newSize)
  ReplaceDimSizeByIndexF _ _ _ = 'Nothing
-- | Replace the size of the first dimension matching a selector.
--
-- A name is first resolved to an index with 'GetIndexByNameF'.
type family
  ReplaceDimSizeImplF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (size' :: Size Nat) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  ReplaceDimSizeImplF ('ByName target) ds newSize =
    ReplaceDimSizeByIndexF (GetIndexByNameF target ds) ds newSize
  ReplaceDimSizeImplF ('ByIndex i) ds newSize = ReplaceDimSizeByIndexF ('Just i) ds newSize
-- | Type error text used by 'ReplaceDimCheckF' when the selected dimension
-- cannot be replaced.
type ReplaceDimErrorMessage
  (selector :: By Symbol Nat)
  (dims :: [Dim (Name Symbol) (Size Nat)])
  (replacement :: Dim (Name Symbol) (Size Nat)) =
  "Cannot replace the first dimension matching"
    % ""
    % " '" <> selector <> "'"
    % ""
    % "in the shape"
    % ""
    % " '" <> dims <> "'"
    % ""
    % "with"
    % ""
    % " '" <> replacement <> "'."
    % ""
-- | Unwrap the result of 'ReplaceDimImplF'.
--
-- A failed replacement (@'Nothing@) is reported with 'ReplaceDimErrorMessage'.
type family
  ReplaceDimCheckF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (dim :: Dim (Name Symbol) (Size Nat))
    (result :: Maybe [Dim (Name Symbol) (Size Nat)]) ::
    [Dim (Name Symbol) (Size Nat)]
  where
  ReplaceDimCheckF selector ds new 'Nothing = TypeError (ReplaceDimErrorMessage selector ds new)
  ReplaceDimCheckF _ _ _ ('Just replaced) = replaced
-- | Replace the first dimension of a shape that matches a selector.
--
-- If the selector or the shape is unchecked, the result is an unchecked shape.
type family
  ReplaceDimF
    (selectDim :: SelectDim (By Symbol Nat))
    (shape :: Shape [Dim (Name Symbol) (Size Nat)])
    (dim :: Dim (Name Symbol) (Size Nat)) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  ReplaceDimF 'UncheckedSelectDim _ _ = 'UncheckedShape
  ReplaceDimF _ 'UncheckedShape _ = 'UncheckedShape
  ReplaceDimF ('SelectDim selector) ('Shape ds) new =
    'Shape (ReplaceDimCheckF selector ds new (ReplaceDimImplF selector ds new))
-- | Insert @dim@ before the dimension at a zero-based index.
--
-- Index 0 matches any list, including the empty one, so an index equal to the
-- list length appends. The catch-all equation yields @'Nothing@ for a
-- @'Nothing@ index, or when the list ends before the index is reached.
type family
  InsertDimByIndexF
    (index :: Maybe Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (dim :: Dim (Name Symbol) (Size Nat)) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  InsertDimByIndexF ('Just 0) ds new = 'Just (new ': ds)
  InsertDimByIndexF ('Just i) (d ': rest) new =
    PrependMaybe ('Just d) (InsertDimByIndexF ('Just (i - 1)) rest new)
  InsertDimByIndexF _ _ _ = 'Nothing
-- | Insert @dim@ before the first dimension matching a selector.
--
-- A name is first resolved to an index with 'GetIndexByNameF'.
type family
  InsertDimImplF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (dim :: Dim (Name Symbol) (Size Nat)) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  InsertDimImplF ('ByName target) ds new = InsertDimByIndexF (GetIndexByNameF target ds) ds new
  InsertDimImplF ('ByIndex i) ds new = InsertDimByIndexF ('Just i) ds new
-- | Type error text used by 'InsertDimCheckF' when a dimension cannot be
-- inserted.
type InsertDimErrorMessage
  (selector :: By Symbol Nat)
  (dims :: [Dim (Name Symbol) (Size Nat)])
  (inserted :: Dim (Name Symbol) (Size Nat)) =
  "Cannot insert the dimension"
    % ""
    % " '" <> inserted <> "'"
    % ""
    % "before the first dimension matching"
    % ""
    % " '" <> selector <> "'"
    % ""
    % "in the shape"
    % ""
    % " '" <> dims <> "'."
    % ""
-- | Unwrap the result of 'InsertDimImplF'.
--
-- A failed insertion (@'Nothing@) is reported with 'InsertDimErrorMessage'.
type family
  InsertDimCheckF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (dim :: Dim (Name Symbol) (Size Nat))
    (result :: Maybe [Dim (Name Symbol) (Size Nat)]) ::
    [Dim (Name Symbol) (Size Nat)]
  where
  InsertDimCheckF selector ds new 'Nothing = TypeError (InsertDimErrorMessage selector ds new)
  InsertDimCheckF _ _ _ ('Just extended) = extended
-- | Insert a dimension into a shape, before the first dimension matching a
-- selector.
--
-- If the selector or the shape is unchecked, the result is an unchecked shape.
type family
  InsertDimF
    (selectDim :: SelectDim (By Symbol Nat))
    (shape :: Shape [Dim (Name Symbol) (Size Nat)])
    (dim :: Dim (Name Symbol) (Size Nat)) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  InsertDimF 'UncheckedSelectDim _ _ = 'UncheckedShape
  InsertDimF _ 'UncheckedShape _ = 'UncheckedShape
  InsertDimF ('SelectDim selector) ('Shape ds) new =
    'Shape (InsertDimCheckF selector ds new (InsertDimImplF selector ds new))
-- | Prepend a dimension to a shape, i.e. insert it before index 0 with
-- 'InsertDimF'.
type family
  PrependDimF
    (dim :: Dim (Name Symbol) (Size Nat))
    (shape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  PrependDimF new s = InsertDimF ('SelectDim ('ByIndex 0)) s new
-- | Remove the dimension at a zero-based index.
--
-- The catch-all equation yields @'Nothing@ for a @'Nothing@ index, or when the
-- list ends before the index is reached.
type family
  RemoveDimByIndexF
    (index :: Maybe Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)]) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  RemoveDimByIndexF ('Just 0) (_ ': rest) = 'Just rest
  RemoveDimByIndexF ('Just i) (d ': rest) =
    PrependMaybe ('Just d) (RemoveDimByIndexF ('Just (i - 1)) rest)
  RemoveDimByIndexF _ _ = 'Nothing
-- | Remove the first dimension matching a selector.
--
-- A name is first resolved to an index with 'GetIndexByNameF'.
type family
  RemoveDimImplF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)]) ::
    Maybe [Dim (Name Symbol) (Size Nat)]
  where
  RemoveDimImplF ('ByName target) ds = RemoveDimByIndexF (GetIndexByNameF target ds) ds
  RemoveDimImplF ('ByIndex i) ds = RemoveDimByIndexF ('Just i) ds
-- | Type error text used by 'RemoveDimCheckF' when a dimension cannot be
-- removed.
type RemoveDimErrorMessage (selector :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot remove the dimension by"
    % ""
    % " '" <> selector <> "'"
    % ""
    % "in the shape"
    % ""
    % " '" <> dims <> "'."
    % ""
-- | Unwrap the result of 'RemoveDimImplF'.
--
-- A failed removal (@'Nothing@) is reported with 'RemoveDimErrorMessage'.
type family
  RemoveDimCheckF
    (by :: By Symbol Nat)
    (dims :: [Dim (Name Symbol) (Size Nat)])
    (result :: Maybe [Dim (Name Symbol) (Size Nat)]) ::
    [Dim (Name Symbol) (Size Nat)]
  where
  RemoveDimCheckF selector ds 'Nothing = TypeError (RemoveDimErrorMessage selector ds)
  RemoveDimCheckF _ _ ('Just remaining) = remaining
-- | Remove the first dimension of a shape that matches a selector.
--
-- If the selector or the shape is unchecked, the result is an unchecked shape.
type family
  RemoveDimF
    (selectDim :: SelectDim (By Symbol Nat))
    (shape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  RemoveDimF 'UncheckedSelectDim _ = 'UncheckedShape
  RemoveDimF _ 'UncheckedShape = 'UncheckedShape
  RemoveDimF ('SelectDim selector) ('Shape ds) =
    'Shape (RemoveDimCheckF selector ds (RemoveDimImplF selector ds))
-- | Error thrown by 'sUnifyName' when two names cannot be unified.
--
-- The fields hold the demoted names of the first and the second argument,
-- respectively.
data UnifyNameError = UnifyNameError {uneExpect :: String, uneActual :: String}
  deriving stock (Show)
-- | Human-readable rendering of a name-unification failure. Both names are
-- shown with 'show', so they appear quoted.
instance Exception UnifyNameError where
  displayException UnifyNameError {..} =
    "The supplied dimensions must be the same, "
      <> "but dimensions with different names were found: "
      <> show uneExpect
      <> " and "
      <> show uneActual
      <> "."
-- | Unify two singleton names at runtime.
--
-- Two names unify when they are equal. They also unify when either of them is
-- the wildcard @"*"@, in which case the other name is returned.
--
-- * When both names are unchecked, the result is unchecked.
-- * When exactly one side is checked, it is demoted and the comparison is
--   redone on two unchecked names.
-- * When both are checked, the surviving singleton is coerced to the
--   '<+>'-computed type.
--
-- If none of the rules apply, a 'UnifyNameError' is thrown via 'throwM'.
sUnifyName ::
  forall m name name'.
  MonadThrow m =>
  SName name ->
  SName name' ->
  m (SName (name <+> name'))
sUnifyName (SUncheckedName name) (SUncheckedName name')
  | name == name' = pure (SUncheckedName name)
sUnifyName (SUncheckedName "*") (SUncheckedName name') = pure (SUncheckedName name')
sUnifyName (SUncheckedName name) (SUncheckedName "*") = pure (SUncheckedName name)
sUnifyName name@SName (SUncheckedName name') =
  sUnifyName (SUncheckedName . forgetIsChecked $ fromSing name) (SUncheckedName name')
sUnifyName (SUncheckedName name) name'@SName =
  sUnifyName (SUncheckedName name) (SUncheckedName . forgetIsChecked $ fromSing name')
sUnifyName name@SName name'@SName
  -- Guards are tried in order; if all fail, matching falls through to the
  -- error equation below. The coercions assume the type-level '<+>' agrees
  -- with these runtime rules.
  | nameStr == nameStr' = pure (unsafeCoerce @(SName name) @(SName (name <+> name')) name)
  | nameStr == "*" = pure (unsafeCoerce @(SName name') @(SName (name <+> name')) name')
  | nameStr' == "*" = pure (unsafeCoerce @(SName name) @(SName (name <+> name')) name)
  where
    nameStr = forgetIsChecked (fromSing name)
    nameStr' = forgetIsChecked (fromSing name')
sUnifyName name name' =
  throwM $ UnifyNameError (forgetIsChecked (fromSing name)) (forgetIsChecked (fromSing name'))
-- | Error thrown by 'sUnifySize' when two sizes cannot be unified.
--
-- The fields hold the demoted sizes of the first and the second argument,
-- respectively.
data UnifySizeError = UnifySizeError {useExpect :: Integer, useActual :: Integer}
  deriving stock (Show)
-- | Human-readable rendering of a size-unification failure.
instance Exception UnifySizeError where
  displayException UnifySizeError {..} =
    "The supplied dimensions must be the same, "
      <> "but dimensions with different sizes were found: "
      <> show useExpect
      <> " and "
      <> show useActual
      <> "."
-- | Unify two singleton sizes at runtime. Two sizes unify only when they are
-- equal.
--
-- * When both sizes are unchecked, the result is unchecked.
-- * When exactly one side is checked, it is demoted and the comparison is
--   redone on two unchecked sizes.
-- * When both are checked, the first singleton is coerced to the
--   '<+>'-computed type.
--
-- Unequal sizes raise a 'UnifySizeError' via 'throwM'.
sUnifySize ::
  forall m size size'.
  MonadThrow m =>
  SSize size ->
  SSize size' ->
  m (SSize (size <+> size'))
sUnifySize (SUncheckedSize size) (SUncheckedSize size')
  | size == size' = pure (SUncheckedSize size)
sUnifySize size@SSize (SUncheckedSize size') =
  sUnifySize (SUncheckedSize . forgetIsChecked $ fromSing size) (SUncheckedSize size')
sUnifySize (SUncheckedSize size) size'@SSize =
  sUnifySize (SUncheckedSize size) (SUncheckedSize . forgetIsChecked $ fromSing size')
sUnifySize size@SSize size'@SSize
  -- The coercion assumes the type-level '<+>' of two equal sizes is that size.
  | forgetIsChecked (fromSing size) == forgetIsChecked (fromSing size') =
      pure (unsafeCoerce @(SSize size) @(SSize (size <+> size')) size)
sUnifySize size size' =
  throwM $ UnifySizeError (forgetIsChecked (fromSing size)) (forgetIsChecked (fromSing size'))
-- | Unify two singleton dimensions. The names are unified with 'sUnifyName'
-- and the sizes with 'sUnifySize'.
--
-- Any 'UnifyNameError' or 'UnifySizeError' thrown by those functions
-- propagates through the 'MonadThrow' context.
sUnifyDim ::
  forall m dim dim'.
  MonadThrow m =>
  SDim dim ->
  SDim dim' ->
  m (SDim (dim <+> dim'))
sUnifyDim (SDim name size) (SDim name' size') = do
  unified <- SDim <$> sUnifyName name name' <*> sUnifySize size size'
  -- The unified singleton is indexed by the component-wise unification of the
  -- names and sizes. The coercion assumes that this agrees with '<+>' on whole
  -- dimensions; the compiler does not check it here.
  pure $ unsafeCoerce @(SDim _) @(SDim (dim <+> dim')) unified