{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.Shape.Class where

import Control.Exception (Exception (..))
import Control.Monad.Catch (MonadThrow (throwM))
import Data.Singletons (Sing, SingKind (..))
import Data.Typeable (Typeable)
import GHC.TypeLits (Symbol, TypeError, type (+), type (-))
import GHC.TypeNats (Nat)
import Torch.GraduallyTyped.Prelude (Fst, LiftTimesMaybe, MapMaybe, PrependMaybe, Reverse, Snd, forgetIsChecked)
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim (..), SName (..), SSelectDim (..), SShape (..), SSize (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import Type.Errors.Pretty (type (%), type (<>))
import Unsafe.Coerce (unsafeCoerce)

-- $setup
-- >>> import Torch.GraduallyTyped.Prelude.List (SList (..))
-- >>> import Torch.GraduallyTyped

-- | Add two sizes. Two checked sizes are summed at the type level; in every
-- other case the two sizes are combined with '<+>'.
type family AddSizeF (size :: Size Nat) (size' :: Size Nat) :: Size Nat where
  AddSizeF ('Size m) ('Size n) = 'Size (m + n)
  AddSizeF lhs rhs = lhs <+> rhs

-- | Add two dimensions: the names are combined with '<+>' and the sizes are
-- added with 'AddSizeF'.
type family AddDimF (dim :: Dim (Name Symbol) (Size Nat)) (dim' :: Dim (Name Symbol) (Size Nat)) :: Dim (Name Symbol) (Size Nat) where
  AddDimF ('Dim n s) ('Dim n' s') = 'Dim (n <+> n') (AddSizeF s s')

-- | Broadcast two sizes.
--
-- * An unchecked size on either side gives an unchecked result.
-- * Equal checked sizes are kept.
-- * A checked size of 1 is stretched to the other side's size.
-- * Any other pair of checked sizes cannot be broadcast and yields 'Nothing'.
type family BroadcastSizeF (size :: Size Nat) (size' :: Size Nat) :: Maybe (Size Nat) where
  BroadcastSizeF 'UncheckedSize _ = 'Just 'UncheckedSize
  BroadcastSizeF _ 'UncheckedSize = 'Just 'UncheckedSize
  BroadcastSizeF ('Size n) ('Size n) = 'Just ('Size n)
  BroadcastSizeF ('Size n) ('Size 1) = 'Just ('Size n)
  BroadcastSizeF ('Size 1) ('Size n) = 'Just ('Size n)
  BroadcastSizeF ('Size _) ('Size _) = 'Nothing

-- | Broadcast two dimensions: the sizes are broadcast with 'BroadcastSizeF'
-- and the names are combined with '<+>', the resulting 'Dim' constructor
-- being applied inside the 'Maybe' via 'MapMaybe'.
type family BroadcastDimF (dim :: Dim (Name Symbol) (Size Nat)) (dim' :: Dim (Name Symbol) (Size Nat)) :: Maybe (Dim (Name Symbol) (Size Nat)) where
  BroadcastDimF ('Dim n s) ('Dim n' s') = MapMaybe ('Dim (n <+> n')) (BroadcastSizeF s s')

-- | Number of elements spanned by a single dimension: its size when the size
-- is checked, 'Nothing' when it is unchecked. The name is ignored.
type family NumelDimF (dim :: Dim (Name Symbol) (Size Nat)) :: Maybe Nat where
  NumelDimF ('Dim _ ('Size n)) = 'Just n
  NumelDimF ('Dim _ 'UncheckedSize) = 'Nothing

-- | Finish a broadcast of two dimension lists. A failed broadcast ('Nothing')
-- becomes a type error that mentions both inputs; a successful one, which
-- 'BroadcastDimsImplF' produces innermost-first, is reversed back into the
-- usual outermost-first order.
type family BroadcastDimsCheckF (dims :: [Dim (Name Symbol) (Size Nat)]) (dims' :: [Dim (Name Symbol) (Size Nat)]) (result :: Maybe [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where
  BroadcastDimsCheckF lhs rhs 'Nothing =
    TypeError
      ( "Cannot broadcast the dimensions"
          % ""
          % "    '" <> lhs <> "' and '" <> rhs <> "'."
          % ""
          % "You may need to extend, squeeze, or unsqueeze the dimensions manually."
      )
  BroadcastDimsCheckF _ _ ('Just reversed) = Reverse reversed

-- | Broadcast two dimension lists that are given innermost-first. Matching
-- positions are broadcast pairwise with 'BroadcastDimF' and the per-position
-- results are combined with 'PrependMaybe'; once either list is exhausted the
-- remaining dimensions of the other list are kept unchanged.
type family BroadcastDimsImplF (reversedDims :: [Dim (Name Symbol) (Size Nat)]) (reversedDims' :: [Dim (Name Symbol) (Size Nat)]) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  BroadcastDimsImplF '[] rest = 'Just rest
  BroadcastDimsImplF rest '[] = 'Just rest
  BroadcastDimsImplF (d ': ds) (d' ': ds') = PrependMaybe (BroadcastDimF d d') (BroadcastDimsImplF ds ds')

-- | Broadcast two dimension lists, aligning them on their last (innermost)
-- dimension by reversing both before calling 'BroadcastDimsImplF'.
type BroadcastDimsF lhs rhs = BroadcastDimsCheckF lhs rhs (BroadcastDimsImplF (Reverse lhs) (Reverse rhs))

-- | Broadcast two shapes. Identical shapes are returned unchanged, two
-- checked shapes are broadcast dimension-wise with 'BroadcastDimsF', and any
-- other combination is resolved with '<+>'.
type family BroadcastShapesF (shape :: Shape [Dim (Name Symbol) (Size Nat)]) (shape' :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  BroadcastShapesF same same = same
  BroadcastShapesF ('Shape ds) ('Shape ds') = 'Shape (BroadcastDimsF ds ds')
  BroadcastShapesF lhs rhs = lhs <+> rhs

-- | Number of elements of a list of dimensions: the per-dimension counts from
-- 'NumelDimF' folded together with 'LiftTimesMaybe', starting from
-- @'Just 1@ for the empty list.
type family NumelDimsF (dims :: [Dim (Name Symbol) (Size Nat)]) :: Maybe Nat where
  NumelDimsF '[] = 'Just 1
  NumelDimsF (d ': ds) = LiftTimesMaybe (NumelDimF d) (NumelDimsF ds)

-- | Number of elements of a shape; 'Nothing' for an unchecked shape.
type family NumelF (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Maybe Nat where
  NumelF ('Shape ds) = NumelDimsF ds
  NumelF 'UncheckedShape = 'Nothing

-- | Search @dims@ for the first dimension whose checked name equals @name@,
-- numbering positions from @index@.
--
-- * On a checked-name match, return that dimension together with its index.
-- * Every dimension with an unchecked name that is passed over replaces the
--   accumulated @result@ by an unchecked dimension with no index. That
--   accumulator is what is returned when the list is exhausted without a
--   checked-name match.
type family GetDimAndIndexByNameF (index :: Nat) (result :: (Maybe (Dim (Name Symbol) (Size Nat)), Maybe Nat)) (name :: Symbol) (dims :: [Dim (Name Symbol) (Size Nat)]) :: (Maybe (Dim (Name Symbol) (Size Nat)), Maybe Nat) where
  GetDimAndIndexByNameF _ acc _ '[] = acc
  GetDimAndIndexByNameF i _ wanted ('Dim 'UncheckedName _ ': rest) =
    GetDimAndIndexByNameF (i + 1) '( 'Just ('Dim 'UncheckedName 'UncheckedSize), 'Nothing) wanted rest
  GetDimAndIndexByNameF i _ wanted ('Dim ('Name wanted) s ': _) =
    '( 'Just ('Dim ('Name wanted) s), 'Just i)
  GetDimAndIndexByNameF i acc wanted ('Dim ('Name _) _ ': rest) =
    GetDimAndIndexByNameF (i + 1) acc wanted rest

-- | The dimension found by 'GetDimAndIndexByNameF', searching from index 0
-- with nothing found yet.
type family GetDimByNameF (name :: Symbol) (dims :: [Dim (Name Symbol) (Size Nat)]) :: Maybe (Dim (Name Symbol) (Size Nat)) where
  GetDimByNameF wanted ds = Fst (GetDimAndIndexByNameF 0 '( 'Nothing, 'Nothing) wanted ds)

-- | The index found by 'GetDimAndIndexByNameF', searching from index 0 with
-- nothing found yet.
type family GetIndexByNameF (name :: Symbol) (dims :: [Dim (Name Symbol) (Size Nat)]) :: Maybe Nat where
  GetIndexByNameF wanted ds = Snd (GetDimAndIndexByNameF 0 '( 'Nothing, 'Nothing) wanted ds)

-- | The dimension at a zero-based position, or 'Nothing' when the position
-- is past the end of the list.
type family GetDimByIndexF (index :: Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) :: Maybe (Dim (Name Symbol) (Size Nat)) where
  GetDimByIndexF 0 (d ': _) = 'Just d
  GetDimByIndexF i (_ ': ds) = GetDimByIndexF (i - 1) ds
  GetDimByIndexF _ _ = 'Nothing

-- | Look up a dimension either by name ('GetDimByNameF') or by position
-- ('GetDimByIndexF').
type family GetDimImplF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) :: Maybe (Dim (Name Symbol) (Size Nat)) where
  GetDimImplF ('ByName n) ds = GetDimByNameF n ds
  GetDimImplF ('ByIndex i) ds = GetDimByIndexF i ds

-- | Type error text reported when no dimension of @dims@ matches @by@.
type GetDimErrorMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot return the first dimension matching"
    % ""
    % "    '" <> by <> "'"
    % ""
    % "in the shape"
    % ""
    % "    '" <> dims <> "'."
    % ""

-- | Unwrap a lookup result, turning a failed lookup into a type error.
type family GetDimCheckF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (result :: Maybe (Dim (Name Symbol) (Size Nat))) :: Dim (Name Symbol) (Size Nat) where
  GetDimCheckF sel ds 'Nothing = TypeError (GetDimErrorMessage sel ds)
  GetDimCheckF _ _ ('Just found) = found

-- | The dimension of a shape selected by @selectDim@. If either the selector
-- or the shape is unchecked, the result is a fully unchecked dimension.
type family GetDimF (selectDim :: SelectDim (By Symbol Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Dim (Name Symbol) (Size Nat) where
  GetDimF 'UncheckedSelectDim _ = 'Dim 'UncheckedName 'UncheckedSize
  GetDimF _ 'UncheckedShape = 'Dim 'UncheckedName 'UncheckedSize
  GetDimF ('SelectDim sel) ('Shape ds) = GetDimCheckF sel ds (GetDimImplF sel ds)

-- | Index into a shape at the type level. A 'Nat' selector picks a dimension
-- by position and a 'Symbol' selector picks it by name; the kind of the
-- selector decides which lookup is used.
type family (!) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) (_k :: k) :: Dim (Name Symbol) (Size Nat) where
  (!) s (i :: Nat) = GetDimF ('SelectDim ('ByIndex i)) s
  (!) s (n :: Symbol) = GetDimF ('SelectDim ('ByName n)) s

-- | Get dimension by index or by name from a shape.
--
-- The selector and the shape may each be checked or unchecked:
--
-- * If either one is unchecked, both are demoted to their runtime values.
--   The search then runs on a plain list and produces an unchecked dimension.
-- * If both are checked, the singleton list is walked directly. The matching
--   singleton is re-typed to the dimension computed by 'GetDimF'.
--
-- When no dimension matches, a 'GetDimError' is thrown in @m@.
--
-- NOTE(review): at the type level, 'GetDimAndIndexByNameF' walks past
-- unchecked-name dimensions and can still resolve to a later checked one.
-- The runtime search below instead returns the first dimension whose name
-- matches. Suppose an unchecked dimension's runtime name equals the requested
-- name and a checked dimension with that name follows it. Then the returned
-- singleton does not agree with the type computed by 'GetDimF'. Confirm
-- whether callers can hit this case.
--
-- >>> shape = SShape $ SName @"batch" :&: SSize @8 :|: SUncheckedName "feature" :&: SSize @2 :|: SNil
-- >>> dim = sGetDimFromShape (SSelectDim $ SByName @"batch") shape
-- >>> :type dim
-- dim :: MonadThrow m => m (SDim ('Dim ('Name "batch") ('Size 8)))
-- >>> fromSing <$> dim
-- Dim {dimName = Checked "batch", dimSize = Checked 8}
--
-- >>> dim = sGetDimFromShape (SSelectDim $ SByName @"feature") shape
-- >>> :type dim
-- dim
--   :: MonadThrow m => m (SDim ('Dim 'UncheckedName 'UncheckedSize))
-- >>> fromSing <$> dim
-- Dim {dimName = Unchecked "feature", dimSize = Checked 2}
--
-- >>> dim = sGetDimFromShape (SSelectDim $ SByName @"sequence") shape
-- >>> :type dim
-- dim
--   :: MonadThrow m => m (SDim ('Dim 'UncheckedName 'UncheckedSize))
-- >>> fromSing <$> dim
-- *** Exception: GetDimError {gdeBy = ByName "sequence"}
--
-- >>> dim = sGetDimFromShape (SSelectDim $ SByIndex @0) shape
-- >>> :type dim
-- dim :: MonadThrow m => m (SDim ('Dim ('Name "batch") ('Size 8)))
-- >>> fromSing <$> dim
-- Dim {dimName = Checked "batch", dimSize = Checked 8}
--
-- >>> :type sGetDimFromShape (SSelectDim $ SByIndex @2) shape
-- sGetDimFromShape (SSelectDim $ SByIndex @2) shape
--   :: MonadThrow m => m (SDim (TypeError ...))
sGetDimFromShape ::
  forall selectDim shape dim m.
  (dim ~ GetDimF selectDim shape, MonadThrow m) =>
  SSelectDim selectDim ->
  SShape shape ->
  m (SDim dim)
-- Selector and shape both known only at runtime: scan the plain list,
-- counting positions for index-based lookups.
sGetDimFromShape (SUncheckedSelectDim by) (SUncheckedShape dims) = go 0 dims
  where
    go :: Integer -> [Dim String Integer] -> m (SDim ('Dim 'UncheckedName 'UncheckedSize))
    go _ [] = throwM $ GetDimErrorWithDims by dims
    go index (Dim name size : dims') =
      let dim' :: SDim ('Dim 'UncheckedName 'UncheckedSize)
          dim' = SDim (SUncheckedName name) (SUncheckedSize size)
       in case by of
            ByName name' | name == name' -> pure dim'
            ByIndex index' | index == index' -> pure dim'
            _ -> go (index + 1) dims'
-- Checked selector, unchecked shape: demote the selector and retry.
sGetDimFromShape (SSelectDim by) (SUncheckedShape dims) =
  sGetDimFromShape (SUncheckedSelectDim $ fromSing by) (SUncheckedShape dims)
-- Unchecked selector, checked shape: demote the shape and retry.
sGetDimFromShape (SUncheckedSelectDim by) (SShape dims) =
  let dims' = (\(Dim name size) -> Dim (forgetIsChecked name) (forgetIsChecked size)) <$> fromSing dims
   in sGetDimFromShape (SUncheckedSelectDim by) (SUncheckedShape dims')
-- Checked name selector, checked shape with no dimensions left: no match.
sGetDimFromShape (SSelectDim by@SByName) (SShape SNil) =
  throwM $ GetDimError (fromSing by)
-- Checked name selector, checked shape: compare the head's runtime name
-- (whether its name is checked or unchecked) with the requested one, and
-- otherwise continue with the tail.
sGetDimFromShape (SSelectDim by@SByName) (SShape (SCons dim dims)) =
  -- The lazy 'ByName' pattern cannot fail: 'SByName' fixes the constructor.
  let ByName name' = fromSing by
      name = (\(Dim n _) -> forgetIsChecked n) (fromSing dim)
   in if name == name'
        then -- GHC cannot relate the found singleton to 'GetDimF'; re-type it.
          pure (unsafeCoerce @(SDim _) @(SDim dim) dim)
        else unsafeCoerce @(SDim _) @(SDim dim) <$> sGetDimFromShape (SSelectDim by) (SShape dims)
-- Checked index selector, checked shape: walk the singleton list, counting
-- positions until the requested index is reached.
sGetDimFromShape (SSelectDim by@SByIndex) (SShape dims) = go 0 dims
  where
    -- The 'ByIndex' pattern cannot fail: 'SByIndex' fixes the constructor.
    by'@(ByIndex index') = fromSing by
    -- Runtime view of the whole shape, only used to report a failed lookup.
    dims' = (\(Dim name size) -> Dim (forgetIsChecked name) (forgetIsChecked size)) <$> fromSing dims
    go :: forall dims. Integer -> SList dims -> m (SDim dim)
    go _ SNil = throwM $ GetDimErrorWithDims by' dims'
    go index (SCons found rest) =
      if index' == index
        then pure (unsafeCoerce @(Sing _) @(SDim dim) found)
        else go (index + 1) rest

-- | Error thrown by 'sGetDimFromShape' when no dimension matches the selector.
data GetDimError
  = -- | No dimension matched; the searched shape is not recorded.
    GetDimError {gdeBy :: By String Integer}
  | -- | No dimension matched in the recorded runtime shape.
    GetDimErrorWithDims {gdewdBy :: By String Integer, gdewdDims :: [Dim String Integer]}
  deriving stock (Show, Typeable)

-- | Human-readable rendering of a 'GetDimError'. It names the selector and,
-- for 'GetDimErrorWithDims', also the shape that was searched.
instance Exception GetDimError where
  displayException GetDimError {..} =
    "Cannot return the first dimension matching `"
      <> show gdeBy
      <> "`."
  displayException GetDimErrorWithDims {..} =
    "Cannot return the first dimension matching `"
      <> show gdewdBy
      <> "` in the shape `"
      <> show gdewdDims
      <> "`."

-- | Replace the dimension at a zero-based position with @dim@, keeping the
-- preceding dimensions via 'PrependMaybe'. An index of 'Nothing', or one that
-- runs off the end of the list, reaches the final 'Nothing' equation.
type family ReplaceDimByIndexF (index :: Maybe Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  ReplaceDimByIndexF ('Just 0) (_ ': rest) new = 'Just (new ': rest)
  ReplaceDimByIndexF ('Just i) (keep ': rest) new = PrependMaybe ('Just keep) (ReplaceDimByIndexF ('Just (i - 1)) rest new)
  ReplaceDimByIndexF _ _ _ = 'Nothing

-- | Replace the dimension selected by name (resolved with 'GetIndexByNameF')
-- or by position.
type family ReplaceDimImplF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  ReplaceDimImplF ('ByName n) ds new = ReplaceDimByIndexF (GetIndexByNameF n ds) ds new
  ReplaceDimImplF ('ByIndex i) ds new = ReplaceDimByIndexF ('Just i) ds new

-- | Replace only the name of the dimension at a zero-based position, keeping
-- its size. An index of 'Nothing', or one that runs off the end of the list,
-- reaches the final 'Nothing' equation.
type family ReplaceDimNameByIndexF (index :: Maybe Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (name :: Name Symbol) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  ReplaceDimNameByIndexF ('Just 0) ('Dim _ s ': rest) newName = 'Just ('Dim newName s ': rest)
  ReplaceDimNameByIndexF ('Just i) (keep ': rest) newName = PrependMaybe ('Just keep) (ReplaceDimNameByIndexF ('Just (i - 1)) rest newName)
  ReplaceDimNameByIndexF _ _ _ = 'Nothing

-- | Rename the dimension selected by name (resolved with 'GetIndexByNameF')
-- or by position.
type family ReplaceDimNameImplF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (name' :: Name Symbol) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  ReplaceDimNameImplF ('ByName n) ds newName = ReplaceDimNameByIndexF (GetIndexByNameF n ds) ds newName
  ReplaceDimNameImplF ('ByIndex i) ds newName = ReplaceDimNameByIndexF ('Just i) ds newName

-- | Replace only the size of the dimension at a zero-based position, keeping
-- its name. An index of 'Nothing', or one that runs off the end of the list,
-- reaches the final 'Nothing' equation.
type family ReplaceDimSizeByIndexF (index :: Maybe Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (size' :: Size Nat) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  ReplaceDimSizeByIndexF ('Just 0) ('Dim n _ ': rest) newSize = 'Just ('Dim n newSize ': rest)
  ReplaceDimSizeByIndexF ('Just i) (keep ': rest) newSize = PrependMaybe ('Just keep) (ReplaceDimSizeByIndexF ('Just (i - 1)) rest newSize)
  ReplaceDimSizeByIndexF _ _ _ = 'Nothing

-- | Resize the dimension selected by name (resolved with 'GetIndexByNameF')
-- or by position.
type family ReplaceDimSizeImplF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (size' :: Size Nat) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  ReplaceDimSizeImplF ('ByName n) ds newSize = ReplaceDimSizeByIndexF (GetIndexByNameF n ds) ds newSize
  ReplaceDimSizeImplF ('ByIndex i) ds newSize = ReplaceDimSizeByIndexF ('Just i) ds newSize

-- | Type error text reported when no dimension of @dims@ matching @by@ can be
-- replaced with @dim@.
type ReplaceDimErrorMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) =
  "Cannot replace the first dimension matching"
    % ""
    % "    '" <> by <> "'"
    % ""
    % "in the shape"
    % ""
    % "    '" <> dims <> "'"
    % ""
    % "with"
    % ""
    % "    '" <> dim <> "'."
    % ""

-- | Unwrap a replacement result, turning a failure into a type error.
type family ReplaceDimCheckF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) (result :: Maybe [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where
  ReplaceDimCheckF sel ds new 'Nothing = TypeError (ReplaceDimErrorMessage sel ds new)
  ReplaceDimCheckF _ _ _ ('Just replaced) = replaced

-- | Replace the selected dimension of a shape with @dim@. If either the
-- selector or the shape is unchecked, the result is an unchecked shape.
type family ReplaceDimF (selectDim :: SelectDim (By Symbol Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) :: Shape [Dim (Name Symbol) (Size Nat)] where
  ReplaceDimF 'UncheckedSelectDim _ _ = 'UncheckedShape
  ReplaceDimF _ 'UncheckedShape _ = 'UncheckedShape
  ReplaceDimF ('SelectDim sel) ('Shape ds) new = 'Shape (ReplaceDimCheckF sel ds new (ReplaceDimImplF sel ds new))

-- | Insert @dim@ before the zero-based position @index@. The first equation
-- accepts any remaining list, so an index equal to the list length appends
-- at the end. An index of 'Nothing', or one beyond the end of the list,
-- reaches the final 'Nothing' equation.
type family InsertDimByIndexF (index :: Maybe Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  InsertDimByIndexF ('Just 0) rest new = 'Just (new ': rest)
  InsertDimByIndexF ('Just i) (keep ': rest) new = PrependMaybe ('Just keep) (InsertDimByIndexF ('Just (i - 1)) rest new)
  InsertDimByIndexF _ _ _ = 'Nothing

-- | Insert @dim@ before the dimension selected by name (resolved with
-- 'GetIndexByNameF') or by position.
type family InsertDimImplF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  InsertDimImplF ('ByName n) ds new = InsertDimByIndexF (GetIndexByNameF n ds) ds new
  InsertDimImplF ('ByIndex i) ds new = InsertDimByIndexF ('Just i) ds new

-- | Type error text reported when @dim@ cannot be inserted before the
-- dimension of @dims@ matching @by@.
type InsertDimErrorMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) =
  "Cannot insert the dimension"
    % ""
    % "    '" <> dim <> "'"
    % ""
    % "before the first dimension matching"
    % ""
    % "    '" <> by <> "'"
    % ""
    % "in the shape"
    % ""
    % "    '" <> dims <> "'."
    % ""

-- | Unwrap an insertion result, turning a failure into a type error.
type family InsertDimCheckF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) (result :: Maybe [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where
  InsertDimCheckF sel ds new 'Nothing = TypeError (InsertDimErrorMessage sel ds new)
  InsertDimCheckF _ _ _ ('Just inserted) = inserted

-- | Insert @dim@ into a shape before the selected dimension. If either the
-- selector or the shape is unchecked, the result is an unchecked shape.
type family InsertDimF (selectDim :: SelectDim (By Symbol Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) (dim :: Dim (Name Symbol) (Size Nat)) :: Shape [Dim (Name Symbol) (Size Nat)] where
  InsertDimF 'UncheckedSelectDim _ _ = 'UncheckedShape
  InsertDimF _ 'UncheckedShape _ = 'UncheckedShape
  InsertDimF ('SelectDim sel) ('Shape ds) new = 'Shape (InsertDimCheckF sel ds new (InsertDimImplF sel ds new))

-- | Put @dim@ in front of all other dimensions of a shape (insertion at
-- index 0).
type family PrependDimF (dim :: Dim (Name Symbol) (Size Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  PrependDimF new s = InsertDimF ('SelectDim ('ByIndex 0)) s new

-- | Drop the dimension at a zero-based position, keeping the preceding
-- dimensions via 'PrependMaybe'. An index of 'Nothing', or one that runs off
-- the end of the list, reaches the final 'Nothing' equation.
type family RemoveDimByIndexF (index :: Maybe Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  RemoveDimByIndexF ('Just 0) (_ ': rest) = 'Just rest
  RemoveDimByIndexF ('Just i) (keep ': rest) = PrependMaybe ('Just keep) (RemoveDimByIndexF ('Just (i - 1)) rest)
  RemoveDimByIndexF _ _ = 'Nothing

-- | Drop the dimension selected by name (resolved with 'GetIndexByNameF') or
-- by position.
type family RemoveDimImplF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) :: Maybe [Dim (Name Symbol) (Size Nat)] where
  RemoveDimImplF ('ByName n) ds = RemoveDimByIndexF (GetIndexByNameF n ds) ds
  RemoveDimImplF ('ByIndex i) ds = RemoveDimByIndexF ('Just i) ds

-- | Type error text reported when no dimension of @dims@ matching @by@ can be
-- removed.
type RemoveDimErrorMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot remove the dimension by"
    % ""
    % "    '" <> by <> "'"
    % ""
    % "in the shape"
    % ""
    % "    '" <> dims <> "'."
    % ""

-- | Unwrap a removal result, turning a failure into a type error.
type family RemoveDimCheckF (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) (result :: Maybe [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where
  RemoveDimCheckF sel ds 'Nothing = TypeError (RemoveDimErrorMessage sel ds)
  RemoveDimCheckF _ _ ('Just remaining) = remaining

-- Example: removing the "batch" dimension by name.
--
-- >>> type SelectBatch = 'SelectDim ('ByName "batch")
-- >>> type Dims = '[ 'Dim ('Name "batch") ('Size 10), 'Dim ('Name "feature") ('Size 8), 'Dim ('Name "anotherFeature") ('Size 12)]
-- >>> :kind! RemoveDimF SelectBatch ('Shape Dims)
-- RemoveDimF SelectBatch ('Shape Dims) :: Shape
--                                           [Dim (Name Symbol) (Size Nat)]
-- = 'Shape
--     '[ 'Dim ('Name "feature") ('Size 8),
--        'Dim ('Name "anotherFeature") ('Size 12)]
-- | Remove the selected dimension from a shape. If either the selector or the
-- shape is unchecked, the result is an unchecked shape; otherwise a failed
-- removal is reported as a type error.
type family RemoveDimF (selectDim :: SelectDim (By Symbol Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  RemoveDimF 'UncheckedSelectDim _ = 'UncheckedShape
  RemoveDimF _ 'UncheckedShape = 'UncheckedShape
  RemoveDimF ('SelectDim sel) ('Shape ds) = 'Shape (RemoveDimCheckF sel ds (RemoveDimImplF sel ds))

-- | Error thrown by 'sUnifyName' when two dimension names cannot be unified.
-- Both names are stored as plain strings, whether or not they were checked.
data UnifyNameError = UnifyNameError {uneExpect :: String, uneActual :: String}
  deriving stock (Show)

-- | Human-readable rendering of a 'UnifyNameError', quoting both names.
instance Exception UnifyNameError where
  displayException UnifyNameError {..} =
    "The supplied dimensions must be the same, "
      <> "but dimensions with different names were found: "
      <> show uneExpect
      <> " and "
      <> show uneActual
      <> "."

-- | Unify two dimension-name singletons at runtime.
--
-- Two names unify when they are equal. They also unify when one of them is
-- the wildcard name @"*"@, in which case the other name is returned. When
-- exactly one side is checked, it is first demoted to an unchecked name, so
-- both sides are compared as plain strings. Names that do not unify cause a
-- 'UnifyNameError' to be thrown in @m@.
sUnifyName ::
  forall m name name'.
  MonadThrow m =>
  SName name ->
  SName name' ->
  m (SName (name <+> name'))
sUnifyName (SUncheckedName name) (SUncheckedName name')
  | name == name' = pure (SUncheckedName name)
sUnifyName (SUncheckedName "*") (SUncheckedName name') = pure (SUncheckedName name')
sUnifyName (SUncheckedName name) (SUncheckedName "*") = pure (SUncheckedName name)
-- Mixed checked/unchecked: demote the checked side and retry.
sUnifyName name@SName (SUncheckedName name') =
  sUnifyName (SUncheckedName . forgetIsChecked $ fromSing name) (SUncheckedName name')
sUnifyName (SUncheckedName name) name'@SName =
  sUnifyName (SUncheckedName name) (SUncheckedName . forgetIsChecked $ fromSing name')
-- Both checked: the unified name is computed at the type level by '<+>', so
-- the chosen singleton only needs to be re-labelled with that type. If no
-- guard holds, evaluation falls through to the error clause below.
sUnifyName name@SName name'@SName
  | lhs == rhs = pure (unsafeCoerce @(SName name) @(SName (name <+> name')) name)
  | lhs == "*" = pure (unsafeCoerce @(SName name') @(SName (name <+> name')) name')
  | rhs == "*" = pure (unsafeCoerce @(SName name) @(SName (name <+> name')) name)
  where
    lhs = forgetIsChecked (fromSing name)
    rhs = forgetIsChecked (fromSing name')
sUnifyName name name' =
  throwM $ UnifyNameError (forgetIsChecked (fromSing name)) (forgetIsChecked (fromSing name'))

-- | Error raised when two dimension sizes cannot be unified. Both sizes are
-- stored as plain integers, whether or not they were checked.
data UnifySizeError = UnifySizeError {useExpect :: Integer, useActual :: Integer}
  deriving stock (Show)

-- | Human-readable rendering of a 'UnifySizeError', listing both sizes.
instance Exception UnifySizeError where
  displayException UnifySizeError {..} =
    "The supplied dimensions must be the same, "
      <> "but dimensions with different sizes were found: "
      <> show useExpect
      <> " and "
      <> show useActual
      <> "."

-- | Unify two sizes at runtime.
--
-- * Two unchecked sizes unify if their values are equal; the result is
--   unchecked.
-- * A checked and an unchecked size are unified by first forgetting the
--   type-level information of the checked one; the result is unchecked.
-- * Two checked sizes unify if their values are equal; the first argument's
--   singleton is returned at the unified type.
--
-- If the values differ, a 'UnifySizeError' is thrown in @m@ with the first
-- size as 'useExpect' and the second as 'useActual'.
sUnifySize ::
  forall m size size'.
  MonadThrow m =>
  SSize size ->
  SSize size' ->
  m (SSize (size <+> size'))
sUnifySize (SUncheckedSize size) (SUncheckedSize size')
  | size == size' = pure (SUncheckedSize size)
sUnifySize size@SSize (SUncheckedSize size') =
  -- Demote the checked size and retry as an unchecked/unchecked unification.
  sUnifySize (SUncheckedSize . forgetIsChecked $ fromSing size) (SUncheckedSize size')
sUnifySize (SUncheckedSize size) size'@SSize =
  -- Same as above, with the checked size in second position.
  sUnifySize (SUncheckedSize size) (SUncheckedSize . forgetIsChecked $ fromSing size')
sUnifySize size@SSize size'@SSize
  | forgetIsChecked (fromSing size) == forgetIsChecked (fromSing size') =
    -- GHC cannot derive @size <+> size' ~ size@ from the runtime equality
    -- check above, hence the coercion. NOTE(review): soundness relies on
    -- 'Unify' reducing two equal checked sizes to that same size — confirm.
    pure (unsafeCoerce @(SSize size) @(SSize (size <+> size')) size)
sUnifySize size size' =
  -- Reached directly when two unchecked or two checked sizes differ; mixed
  -- checked/unchecked inputs reach it through the recursive calls above.
  throwM $
    UnifySizeError
      (forgetIsChecked (fromSing size))
      (forgetIsChecked (fromSing size'))

-- | Unify two dimensions.
--
-- The names are unified with 'sUnifyName' and the sizes with 'sUnifySize';
-- if either unification fails, its exception is thrown in @m@.
--
-- >>> dimA = SName @"*" :&: SSize @0
-- >>> dimB = SName @"batch" :&: SSize @0
-- >>> dim = sUnifyDim dimA dimB
-- >>> :type dim
-- dim :: MonadThrow m => m (SDim ('Dim ('Name "batch") ('Size 0)))
-- >>> fromSing <$> dim
-- Dim {dimName = Checked "batch", dimSize = Checked 0}
--
-- >>> dimC = SName @"feature" :&: SSize @0
-- >>> :type sUnifyDim dimB dimC
-- sUnifyDim dimB dimC
--   :: MonadThrow m => m (SDim ('Dim (TypeError ...) ('Size 0)))
--
-- >>> dimD = SUncheckedName "batch" :&: SSize @0
-- >>> dim = sUnifyDim dimA dimD
-- >>> :type dim
-- dim :: MonadThrow m => m (SDim ('Dim 'UncheckedName ('Size 0)))
-- >>> fromSing <$> dim
-- Dim {dimName = Unchecked "batch", dimSize = Checked 0}
--
-- >>> dimE = SUncheckedName "feature" :&: SSize @0
-- >>> dim = sUnifyDim dimB dimE
-- >>> :type dim
-- dim :: MonadThrow m => m (SDim ('Dim 'UncheckedName ('Size 0)))
-- >>> fromSing <$> dim
-- *** Exception: UnifyNameError {uneExpect = "batch", uneActual = "feature"}
sUnifyDim ::
  forall m dim dim'.
  MonadThrow m =>
  SDim dim ->
  SDim dim' ->
  m (SDim (dim <+> dim'))
sUnifyDim (SDim name size) (SDim name' size') = do
  dim <- SDim <$> sUnifyName name name' <*> sUnifySize size size'
  -- 'dim' is built component-wise, so its type is
  -- @SDim ('Dim (name <+> name') (size <+> size'))@, which GHC does not
  -- identify with @SDim (dim <+> dim')@. NOTE(review): the coercion assumes
  -- unification of dims reduces component-wise — confirm against the
  -- 'Unify' definition for 'Dim'.
  pure $ unsafeCoerce @(SDim _) @(SDim (dim <+> dim')) dim