{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.Layout where
import Data.Kind (Type)
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import Data.Singletons.TH (genSingletons)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..))
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen (kSparse, kStrided)
import qualified Torch.Internal.Type as ATen (Layout)
-- | Memory layout of a tensor at the term level.
--
-- 'Dense' is converted to 'ATen.kStrided' and 'Sparse' to 'ATen.kSparse'
-- by the 'Castable' instance for @'LayoutType'@ / @'ATen.Layout'@.
data LayoutType
  = -- | Dense (strided) memory layout.
    Dense
  | -- | Sparse memory layout.
    Sparse
  deriving stock (Show, Eq)
-- Generate the singleton machinery for 'LayoutType', including the
-- 'SLayoutType' singleton type used elsewhere in this module.
$(genSingletons [''LayoutType])

-- | Derived 'Show' for the generated 'LayoutType' singletons.
deriving stock instance Show (SLayoutType (lt :: LayoutType))
-- | Reflects a promoted 'LayoutType' back to its term-level value.
--
-- The class parameter does not appear in the method's type
-- (hence @AllowAmbiguousTypes@), so callers select the instance with a
-- type application, e.g. @layoutTypeVal \@'Dense@.
class KnownLayoutType (lt :: LayoutType) where
  -- | The term-level 'LayoutType' corresponding to @lt@.
  layoutTypeVal :: LayoutType
-- | The promoted @'Dense@ reflects to the term-level 'Dense'.
instance KnownLayoutType 'Dense where
  layoutTypeVal = Dense
-- | The promoted @'Sparse@ reflects to the term-level 'Sparse'.
instance KnownLayoutType 'Sparse where
  layoutTypeVal = Sparse
-- | Conversion between 'LayoutType' and the raw ATen layout value.
--
-- 'Dense' maps to 'ATen.kStrided' and 'Sparse' maps to 'ATen.kSparse'.
instance Castable LayoutType ATen.Layout where
  cast Dense f = f ATen.kStrided
  cast Sparse f = f ATen.kSparse

  -- Only strided and sparse ATen layouts have a 'LayoutType' counterpart.
  -- Any other value fails in 'IO' with a descriptive message instead of an
  -- anonymous non-exhaustive-guards error.
  uncast x f
    | x == ATen.kStrided = f Dense
    | x == ATen.kSparse = f Sparse
    | otherwise =
        ioError
          ( userError
              "Torch.GraduallyTyped.Layout.uncast: unsupported ATen layout (only kStrided and kSparse are supported)"
          )
-- | A layout that is either unchecked or checked, in which case it carries
-- a value of type @layoutType@.
--
-- The parameter has kind 'Type'. In this module it is used at 'LayoutType',
-- and its singleton 'SLayout' demotes to @'IsChecked' 'LayoutType'@:
-- 'UncheckedLayout' corresponds to 'Unchecked' and 'Layout' to 'Checked'.
data Layout (layoutType :: Type) where
  -- | No layout is recorded.
  UncheckedLayout :: forall layoutType. Layout layoutType
  -- | A layout of type @layoutType@ is recorded.
  Layout :: forall layoutType. layoutType -> Layout layoutType
  deriving stock (Show)
-- | Singleton type for the promoted kind @'Layout' 'LayoutType'@.
--
-- The unchecked case still carries a term-level 'LayoutType'. The checked
-- case carries the 'SLayoutType' singleton of its layout type.
data SLayout (l :: Layout LayoutType) where
  SUncheckedLayout ::
    LayoutType ->
    SLayout 'UncheckedLayout
  SLayout ::
    forall lt.
    SLayoutType lt ->
    SLayout ('Layout lt)

deriving stock instance Show (SLayout (l :: Layout LayoutType))

-- Register 'SLayout' as the 'Sing' type for kind @'Layout' 'LayoutType'@.
type instance Sing = SLayout
-- | A checked layout has a known singleton whenever its 'LayoutType' does.
-- It wraps that layout type's singleton in 'SLayout'.
instance SingI layoutType => SingI ('Layout (layoutType :: LayoutType)) where
  sing = SLayout (sing @layoutType)
-- | Demotion and promotion between the singletons of
-- @'Layout' 'LayoutType'@ and the term-level @'IsChecked' 'LayoutType'@.
--
-- 'SUncheckedLayout' corresponds to 'Unchecked' and 'SLayout' to 'Checked'.
instance SingKind (Layout LayoutType) where
  type Demote (Layout LayoutType) = IsChecked LayoutType

  fromSing (SUncheckedLayout layoutType) = Unchecked layoutType
  fromSing (SLayout layoutType) = Checked (fromSing layoutType)

  toSing (Unchecked layoutType) = SomeSing (SUncheckedLayout layoutType)
  -- Promote the inner 'LayoutType' to its singleton, then wrap it in
  -- 'SLayout'. An explicit lambda avoids relying on '$'-specific typing
  -- for the rank-2 continuation argument.
  toSing (Checked layoutType) =
    withSomeSing layoutType (\sLayoutType -> SomeSing (SLayout sLayoutType))
-- | Reflects a promoted @'Layout' 'LayoutType'@ back to its term-level value.
--
-- The class parameter does not appear in the method's type
-- (hence @AllowAmbiguousTypes@), so callers select the instance with a
-- type application, e.g. @layoutVal \@'UncheckedLayout@.
class KnownLayout (l :: Layout LayoutType) where
  -- | The term-level 'Layout' corresponding to @l@.
  layoutVal :: Layout LayoutType
-- | The promoted @'UncheckedLayout@ reflects to 'UncheckedLayout'.
instance KnownLayout 'UncheckedLayout where
  layoutVal = UncheckedLayout
-- | A checked layout reflects to 'Layout' wrapping the reflected
-- 'LayoutType' of its parameter.
instance (KnownLayoutType layoutType) => KnownLayout ('Layout layoutType) where
  layoutVal = Layout (layoutTypeVal @layoutType)
-- | Collects every @'Layout' 'LayoutType'@ occurring in a type.
--
-- The equations are tried in order:
--
-- * a type of kind @'Layout' 'LayoutType'@ yields a one-element list;
-- * a type application is searched on both sides, and the two results are
--   combined with 'Concat';
-- * anything else contributes nothing.
type GetLayouts :: forall k. k -> [Layout LayoutType]
type family GetLayouts t where
  GetLayouts (layout :: Layout LayoutType) = '[layout]
  GetLayouts (fun arg) = Concat (GetLayouts fun) (GetLayouts arg)
  GetLayouts _ = '[]