{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.Layout where

import Data.Kind (Type)
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import Data.Singletons.TH (genSingletons)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..))
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen (kSparse, kStrided)
import qualified Torch.Internal.Type as ATen (Layout)

-- | Data type that represents the memory layout of a tensor.
--
-- It is lifted to the type level by the 'genSingletons' splice that follows,
-- so that layouts can also be tracked by the compiler.
data LayoutType
  = -- | The memory layout of the tensor is dense (strided).
    Dense
  | -- | The memory layout of the tensor is sparse.
    Sparse
  deriving (Show, Eq)

-- Template Haskell: generate the singleton type 'SLayoutType' for
-- 'LayoutType', together with the singletons-library instances this module
-- relies on below (e.g. 'SingI' for @sing \@layoutType@ and 'SingKind' for
-- 'withSomeSing' on 'LayoutType' values).
genSingletons [''LayoutType]

-- 'SLayoutType' is produced by the splice above, so its 'Show' instance is
-- derived standalone here.
deriving stock instance Show (SLayoutType (layoutType :: LayoutType))

-- | Reflect a type-level 'LayoutType' to its term-level value.
--
-- The class parameter does not occur in the method's type (this relies on
-- @AllowAmbiguousTypes@), so callers must pick the instance with a type
-- application, e.g. @layoutTypeVal \@layoutType@.
class KnownLayoutType (layoutType :: LayoutType) where
  -- | The term-level 'LayoutType' corresponding to @layoutType@.
  layoutTypeVal :: LayoutType

-- | The promoted 'Dense' layout reflects to the term-level 'Dense'.
instance KnownLayoutType 'Dense where
  layoutTypeVal = Dense

-- | The promoted 'Sparse' layout reflects to the term-level 'Sparse'.
instance KnownLayoutType 'Sparse where
  layoutTypeVal = Sparse

-- | Conversion between 'LayoutType' and the raw ATen layout code.
--
-- 'Dense' corresponds to @ATen.kStrided@ and 'Sparse' to @ATen.kSparse@.
instance Castable LayoutType ATen.Layout where
  cast Dense f = f ATen.kStrided
  cast Sparse f = f ATen.kSparse

  uncast x f
    | x == ATen.kStrided = f Dense
    | x == ATen.kSparse = f Sparse
    | otherwise =
        -- 'LayoutType' has no constructor for any other ATen layout code, so
        -- fail in 'IO' with a descriptive message instead of crashing with an
        -- anonymous non-exhaustive-guards error.
        ioError . userError $
          "Torch.GraduallyTyped.Layout.uncast: unsupported ATen layout (expected kStrided or kSparse)"

-- | Data type to represent whether or not the tensor's memory layout is checked, that is, known to the compiler.
--
-- It is used promoted at kind @'Layout' 'LayoutType'@ (see 'SLayout' and
-- 'KnownLayout'), and demoted via 'IsChecked'.
data Layout (layoutType :: Type) where
  -- | The tensor's memory layout is unknown to the compiler.
  UncheckedLayout :: forall layoutType. Layout layoutType
  -- | The tensor's memory layout is known to the compiler.
  Layout :: forall layoutType. layoutType -> Layout layoutType
  deriving (Show)

-- | Singleton type for promoted values of kind @'Layout' 'LayoutType'@.
--
-- It is registered as the 'Sing' instance for that kind below.
type SLayout :: Layout LayoutType -> Type
data SLayout layout where
  -- | Singleton for 'UncheckedLayout'. The type index carries no layout
  -- type, so the layout type is stored as a plain runtime 'LayoutType'.
  SUncheckedLayout :: LayoutType -> SLayout 'UncheckedLayout
  -- | Singleton for @'Layout' layoutType@, wrapping the singleton of the
  -- statically known layout type.
  SLayout :: SLayoutType layoutType -> SLayout ('Layout layoutType)

-- 'Show' is derived standalone because 'SLayout' is a GADT with refined
-- result types, which a plain deriving clause cannot handle.
deriving stock instance Show (SLayout layout)

type instance Sing = SLayout

-- | The singleton of a checked layout is built from the singleton of its
-- statically known 'LayoutType'.
instance SingI layoutType => SingI ('Layout (layoutType :: LayoutType)) where
  sing = SLayout $ sing @layoutType

-- | Converts between promoted layouts and their term-level representation
-- 'IsChecked' 'LayoutType'.
--
-- 'SUncheckedLayout' maps to 'Unchecked' (keeping its runtime layout type),
-- and @'SLayout' s@ maps to 'Checked' of the demoted layout-type singleton.
instance SingKind (Layout LayoutType) where
  type Demote (Layout LayoutType) = IsChecked LayoutType

  fromSing (SUncheckedLayout layoutType) = Unchecked layoutType
  fromSing (SLayout layoutType) = Checked (fromSing layoutType)

  toSing (Unchecked layoutType) = SomeSing (SUncheckedLayout layoutType)
  -- Promote the runtime 'LayoutType' to a singleton first, then wrap it.
  toSing (Checked layoutType) = withSomeSing layoutType $ SomeSing . SLayout

-- | Reflect a type-level @'Layout' 'LayoutType'@ to its term-level value.
--
-- As with 'KnownLayoutType', the class parameter does not occur in the
-- method's type, so the instance must be selected with a type application.
class KnownLayout (layout :: Layout LayoutType) where
  -- | The term-level layout corresponding to @layout@.
  layoutVal :: Layout LayoutType

-- | A promoted 'UncheckedLayout' reflects to the term-level 'UncheckedLayout'.
instance KnownLayout 'UncheckedLayout where
  layoutVal = UncheckedLayout

-- | A promoted @'Layout' layoutType@ reflects to 'Layout' applied to the
-- reflected value of @layoutType@.
instance (KnownLayoutType layoutType) => KnownLayout ('Layout layoutType) where
  layoutVal = Layout (layoutTypeVal @layoutType)

-- Collect every promoted layout occurring in a (possibly nested) type.
--
-- A type of kind @Layout LayoutType@ yields a singleton list, a type
-- application @f g@ yields the concatenation of the layouts found in @f@ and
-- in @g@, and any other type yields the empty list.
--
-- >>> :kind! GetLayouts ('Layout 'Dense)
-- GetLayouts ('Layout 'Dense) :: [Layout LayoutType]
-- = '[ 'Layout 'Dense]
-- >>> :kind! GetLayouts '[ 'Layout 'Sparse, 'Layout 'Dense]
-- GetLayouts '[ 'Layout 'Sparse, 'Layout 'Dense] :: [Layout
--                                                      LayoutType]
-- = '[ 'Layout 'Sparse, 'Layout 'Dense]
-- >>> :kind! GetLayouts ('Just ('Layout 'Dense))
-- GetLayouts ('Just ('Layout 'Dense)) :: [Layout LayoutType]
-- = '[ 'Layout 'Dense]
type GetLayouts :: k -> [Layout LayoutType]
type family GetLayouts f where
  -- Equation order matters: a promoted 'Layout x is itself a type
  -- application, so it must be matched here before the (f g) equation
  -- below, which would otherwise split it into 'Layout and x.
  GetLayouts (a :: Layout LayoutType) = '[a]
  -- Recurse into both sides of any type application (e.g. type-level lists
  -- and 'Just, as in the examples above).
  GetLayouts (f g) = Concat (GetLayouts f) (GetLayouts g)
  GetLayouts _ = '[]