{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.Layout where
import Data.Kind (Type)
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import Data.Singletons.TH (genSingletons)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..))
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen (kSparse, kStrided)
import qualified Torch.Internal.Type as ATen (Layout)
-- | Layout tags distinguished by this module.
data LayoutType
  = -- | Dense layout; the 'Castable' instance in this module converts it to
    -- and from @ATen.kStrided@.
    Dense
  | -- | Sparse layout; the 'Castable' instance in this module converts it to
    -- and from @ATen.kSparse@.
    Sparse
  deriving stock (Show, Eq)
-- Template Haskell splice that generates the singleton definitions for
-- 'LayoutType'. @SLayoutType@, used below, is neither imported nor declared by
-- hand in this module, so it presumably comes from this splice -- keep the
-- splice ahead of its first use.
genSingletons [''LayoutType]
-- 'Show' for the singleton type of 'LayoutType', derived standalone with the
-- stock strategy.
deriving stock instance Show (SLayoutType (layoutType :: LayoutType))
-- | Type-level 'LayoutType's whose value-level counterpart can be recovered.
--
-- The class parameter does not appear in the method's type, so the instance
-- has to be picked with a type application, e.g. @layoutTypeVal \@'Dense@.
class KnownLayoutType (layoutType :: LayoutType) where
  -- | The value-level 'LayoutType' corresponding to @layoutType@.
  layoutTypeVal :: LayoutType
-- | The promoted constructor @'Dense@ reflects to the value 'Dense'.
instance KnownLayoutType 'Dense where
  layoutTypeVal = Dense
-- | The promoted constructor @'Sparse@ reflects to the value 'Sparse'.
instance KnownLayoutType 'Sparse where
  layoutTypeVal = Sparse
-- | Conversion between 'LayoutType' and @ATen.Layout@, written in
-- continuation-passing style: each method hands the converted value to the
-- supplied continuation @f@ and returns whatever 'IO' action @f@ produces.
instance Castable LayoutType ATen.Layout where
  -- 'Dense' is handed on as @ATen.kStrided@, 'Sparse' as @ATen.kSparse@.
  cast Dense f = f ATen.kStrided
  cast Sparse f = f ATen.kSparse

  -- Inverse of 'cast' for the two constants above. Any other @ATen.Layout@
  -- value has no 'LayoutType' counterpart; instead of falling off the end of
  -- the guards with an anonymous pattern-match failure, report it as an
  -- 'IOError' that names this function.
  uncast x f
    | x == ATen.kStrided = f Dense
    | x == ATen.kSparse = f Sparse
    | otherwise =
      ioError . userError $
        "Torch.GraduallyTyped.Layout.uncast: unsupported ATen layout (expected kStrided or kSparse)"
-- | A layout that is either left unspecified or pinned to a @layoutType@.
--
-- This module uses the type promoted, at kind @Layout LayoutType@, as the
-- index of 'SLayout' and the parameter of 'KnownLayout'.
data Layout (layoutType :: Type) where
  -- | Carries no layout information.
  UncheckedLayout :: forall layoutType. Layout layoutType
  -- | Wraps a known @layoutType@.
  Layout :: forall layoutType. layoutType -> Layout layoutType
  deriving stock (Show)
-- | Singleton-style GADT indexed by a type-level @Layout LayoutType@.
data SLayout (layout :: Layout LayoutType) where
  -- | Index @'UncheckedLayout@: the type carries no layout, so the
  -- 'LayoutType' is stored as an ordinary runtime value.
  SUncheckedLayout :: LayoutType -> SLayout 'UncheckedLayout
  -- | Index @'Layout layoutType@: wraps the singleton of the type-level
  -- @layoutType@.
  SLayout :: forall layoutType. SLayoutType layoutType -> SLayout ('Layout layoutType)
-- 'Show' for 'SLayout', derived standalone with the stock strategy.
deriving stock instance Show (SLayout (layout :: Layout LayoutType))
-- Registers 'SLayout' as the 'Sing' instance for types of kind
-- @Layout LayoutType@.
type instance Sing = SLayout
-- | A checked layout has an implicit singleton whenever its 'LayoutType' has
-- one: the inner singleton is wrapped in 'SLayout'.
instance SingI layoutType => SingI ('Layout (layoutType :: LayoutType)) where
  sing = SLayout (sing @layoutType)
-- | Demotion and promotion between 'SLayout' singletons and runtime
-- @IsChecked LayoutType@ values.
instance SingKind (Layout LayoutType) where
  -- The runtime counterpart records whether the layout was checked.
  type Demote (Layout LayoutType) = IsChecked LayoutType

  -- An unchecked singleton already stores its 'LayoutType'; a checked one is
  -- demoted through the 'LayoutType' singleton it wraps.
  fromSing (SUncheckedLayout layoutType) = Unchecked layoutType
  fromSing (SLayout layoutType) = Checked (fromSing layoutType)

  -- Mirror image of 'fromSing': 'Unchecked' keeps the value at runtime,
  -- 'Checked' promotes it to a 'LayoutType' singleton and wraps that in
  -- 'SLayout'.
  toSing (Unchecked layoutType) = SomeSing (SUncheckedLayout layoutType)
  toSing (Checked layoutType) = withSomeSing layoutType $ SomeSing . SLayout
-- | Type-level layouts whose value-level counterpart can be recovered.
--
-- The class parameter does not appear in the method's type, so the instance
-- has to be picked with a type application, e.g.
-- @layoutVal \@'UncheckedLayout@.
class KnownLayout (layout :: Layout LayoutType) where
  -- | The value-level @Layout LayoutType@ corresponding to @layout@.
  layoutVal :: Layout LayoutType
-- | The promoted @'UncheckedLayout@ reflects to the value 'UncheckedLayout'.
instance KnownLayout 'UncheckedLayout where
  layoutVal = UncheckedLayout
-- | A checked layout reflects to 'Layout' applied to the value of its
-- type-level 'LayoutType', obtained through 'KnownLayoutType'.
instance (KnownLayoutType layoutType) => KnownLayout ('Layout layoutType) where
  layoutVal = Layout (layoutTypeVal @layoutType)
-- | Collects the types of kind @Layout LayoutType@ that occur in @f@ into a
-- type-level list. Equations are tried top to bottom:
--
-- * a type that is itself a layout yields the one-element list holding it;
-- * a type application yields the results for its head and its argument,
--   combined with 'Concat';
-- * anything else yields the empty list.
type GetLayouts :: k -> [Layout LayoutType]
type family GetLayouts f where
  GetLayouts (layout :: Layout LayoutType) = '[layout]
  GetLayouts (f x) = Concat (GetLayouts f) (GetLayouts x)
  GetLayouts _ = '[]