{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.Shape.Type where
import Data.Bifunctor (Bifunctor (..))
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import Foreign.ForeignPtr (ForeignPtr)
import GHC.TypeLits (KnownNat, KnownSymbol, Nat, SomeNat (..), SomeSymbol (..), Symbol, natVal, someNatVal, someSymbolVal, symbolVal)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..), forgetIsChecked)
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Managed.Cast as ATen ()
import qualified Torch.Internal.Managed.Type.Dimname as ATen (dimname_symbol, fromSymbol_s)
import qualified Torch.Internal.Managed.Type.DimnameList as ATen (dimnameList_at_s, dimnameList_push_back_n, dimnameList_size, newDimnameList)
import qualified Torch.Internal.Managed.Type.IntArray as ATen (intArray_at_s, intArray_push_back_l, intArray_size, newIntArray)
import qualified Torch.Internal.Managed.Type.StdString as ATen (newStdString_s, string_c_str)
import qualified Torch.Internal.Managed.Type.Symbol as ATen (dimname_s, symbol_toUnqualString)
import qualified Torch.Internal.Type as ATen (Dimname, DimnameList, IntArray)
-- | A gradually typed dimension size.
--
-- At the type level it is used at kind @'Size' 'Nat'@ (see 'SSize'); at the
-- term level it wraps a runtime value such as an 'Integer' (see 'KnownSize').
data Size (size :: Type) where
  -- | The size is not tracked.
  UncheckedSize ::
    forall size.
    Size size
  -- | The size is tracked and carries a value.
  Size ::
    forall size.
    size ->
    Size size
  deriving stock (Show)
-- | Singleton for type-level sizes of kind @'Size' 'Nat'@.
data SSize (size :: Size Nat) where
  -- | Unchecked size; its value is only available at runtime.
  SUncheckedSize :: Integer -> SSize 'UncheckedSize
  -- | Checked size; its value is recoverable from the 'KnownNat' evidence.
  SSize :: forall size. KnownNat size => SSize ('Size size)

deriving stock instance Show (SSize (size :: Size Nat))

type instance Sing = SSize

instance KnownNat size => SingI ('Size size) where
  sing = SSize @size

-- | Extract the type-level 'Nat' from a checked @'Size' 'Nat'@.
type family SizeF (size :: Size Nat) :: Nat where
  SizeF ('Size size) = size
-- | Conversion between 'SSize' singletons and their demoted form
-- @'IsChecked' 'Integer'@.
--
-- 'SUncheckedSize' demotes to 'Unchecked' and 'SSize' demotes to 'Checked'.
-- Promoting a negative checked size raises an error, because no
-- type-level 'Nat' represents it.
instance SingKind (Size Nat) where
  type Demote (Size Nat) = IsChecked Integer
  fromSing (SUncheckedSize size) = Unchecked size
  fromSing (SSize :: Sing size) =
    -- 'SizeF' recovers the 'Nat' whose 'KnownNat' evidence the 'SSize'
    -- constructor brought into scope.
    Checked . natVal $ Proxy @(SizeF size)
  toSing (Unchecked size) = SomeSing . SUncheckedSize $ size
  toSing (Checked size) =
    case someNatVal size of
      Just (SomeNat (_ :: Proxy n)) -> SomeSing (SSize @n)
      -- 'someNatVal' returns 'Nothing' only for negative inputs; report that
      -- explicitly instead of dying on a non-exhaustive pattern.
      Nothing ->
        error $
          "Torch.GraduallyTyped.Shape.Type.toSing: cannot promote negative size "
            ++ show size
            ++ " to a type-level Nat"
-- | Reflect a type-level @'Size' 'Nat'@ to a term-level @'Size' 'Integer'@.
class KnownSize (size :: Size Nat) where
  sizeVal :: Size Integer

-- | An unchecked size reflects to 'UncheckedSize'.
instance KnownSize 'UncheckedSize where
  sizeVal = UncheckedSize

-- | A checked size reflects to its 'KnownNat' value.
instance KnownNat size => KnownSize ('Size size) where
  sizeVal = Size (natVal (Proxy @size))
-- | A gradually typed dimension name.
--
-- At the type level it is used at kind @'Name' 'Symbol'@ (see 'SName'); at
-- the term level it wraps a runtime value such as a 'String'
-- (see 'KnownName').
data Name (name :: Type) where
  -- | The name is not tracked.
  UncheckedName ::
    forall name.
    Name name
  -- | The name is tracked and carries a value.
  Name ::
    forall name.
    name ->
    Name name
  deriving stock (Show)
-- | Singleton for type-level names of kind @'Name' 'Symbol'@.
data SName (name :: Name Symbol) where
  -- | Unchecked name; its value is only available at runtime.
  SUncheckedName :: String -> SName 'UncheckedName
  -- | Checked name; its value is recoverable from the 'KnownSymbol' evidence.
  SName :: forall name. KnownSymbol name => SName ('Name name)

deriving stock instance Show (SName (name :: Name Symbol))

-- | The checked name @\"*\"@. Judging by the pattern's name, @\"*\"@ is the
-- placeholder for an unnamed dimension (NOTE(review): confirm convention).
pattern SNoName :: SName ('Name "*")
pattern SNoName = SName

type instance Sing = SName

instance KnownSymbol name => SingI ('Name name) where
  sing = SName @name

-- | Extract the type-level 'Symbol' from a checked @'Name' 'Symbol'@.
type family NameF (name :: Name Symbol) :: Symbol where
  NameF ('Name name) = name
-- | Conversion between 'SName' singletons and their demoted form
-- @'IsChecked' 'String'@.
instance SingKind (Name Symbol) where
  type Demote (Name Symbol) = IsChecked String
  fromSing (SUncheckedName name) = Unchecked name
  fromSing (SName :: Sing name) =
    -- 'NameF' recovers the 'Symbol' whose 'KnownSymbol' evidence 'SName'
    -- brought into scope.
    Checked (symbolVal (Proxy @(NameF name)))
  toSing (Unchecked name) = SomeSing (SUncheckedName name)
  toSing (Checked name) =
    case someSymbolVal name of
      SomeSymbol (_ :: Proxy s) -> SomeSing (SName @s)
-- | Reflect a type-level @'Name' 'Symbol'@ to a term-level @'Name' 'String'@.
class KnownName (name :: Name Symbol) where
  nameVal :: Name String

-- | An unchecked name reflects to 'UncheckedName'.
instance KnownName 'UncheckedName where
  nameVal = UncheckedName

-- | A checked name reflects to its 'KnownSymbol' value.
instance KnownSymbol name => KnownName ('Name name) where
  nameVal = Name (symbolVal (Proxy @name))
-- | A tensor dimension: a name paired with a size.
--
-- Both components are abstract. 'SDim' uses
-- @'Dim' ('Name' 'Symbol') ('Size' 'Nat')@ at the type level, and
-- 'KnownDim' produces @'Dim' ('Name' 'String') ('Size' 'Integer')@.
data Dim (name :: Type) (size :: Type) where
  Dim ::
    forall name size.
    { dimName :: name,
      dimSize :: size
    } ->
    Dim name size
  deriving stock (Eq, Ord, Show)
-- | Map the name with the first function and the size with the second.
-- This matches on the 'Dim' constructor, so the input is forced to WHNF.
instance Bifunctor Dim where
  bimap onName onSize (Dim n s) =
    Dim {dimName = onName n, dimSize = onSize s}
-- | Singleton for type-level dimensions: an 'SName' paired with an 'SSize'.
data SDim (dim :: Dim (Name Symbol) (Size Nat)) where
  SDim ::
    forall name size.
    { sDimName :: SName name,
      sDimSize :: SSize size
    } ->
    SDim ('Dim name size)

deriving stock instance Show (SDim (dim :: Dim (Name Symbol) (Size Nat)))

type instance Sing = SDim

-- | Only fully checked dimensions (checked name and checked size) have an
-- implicit singleton.
instance (KnownSymbol name, KnownNat size) => SingI ('Dim ('Name name) ('Size size)) where
  sing =
    SDim
      { sDimName = sing @('Name name),
        sDimSize = sing @('Size size)
      }
-- | Conversion between 'SDim' singletons and a term-level 'Dim' whose name
-- and size are each demoted independently.
instance SingKind (Dim (Name Symbol) (Size Nat)) where
  type Demote (Dim (Name Symbol) (Size Nat)) = Dim (IsChecked String) (IsChecked Integer)
  fromSing (SDim sName sSize) =
    Dim {dimName = fromSing sName, dimSize = fromSing sSize}
  toSing (Dim name size) =
    withSomeSing name (\sName ->
      withSomeSing size (\sSize ->
        SomeSing (SDim sName sSize)))
-- | Infix synonym for 'SDim': @sName :&: sSize@ builds or matches a
-- dimension singleton.
pattern (:&:) ::
  forall (name :: Name Symbol) (size :: Size Nat).
  SName name ->
  SSize size ->
  SDim ('Dim name size)
pattern sName :&: sSize = SDim sName sSize

infix 9 :&:
-- | Reflect a type-level dimension to a term-level one by reflecting its
-- name ('KnownName') and size ('KnownSize').
class KnownDim (dim :: Dim (Name Symbol) (Size Nat)) where
  dimVal :: Dim (Name String) (Size Integer)

instance (KnownName name, KnownSize size) => KnownDim ('Dim name size) where
  dimVal = Dim {dimName = nameVal @name, dimSize = sizeVal @size}
-- | Selects a dimension either by its name or by its index.
data By (name :: Type) (index :: Type) where
  -- | Select by name.
  ByName ::
    forall name index.
    name ->
    By name index
  -- | Select by index.
  ByIndex ::
    forall name index.
    index ->
    By name index
  deriving stock (Show, Eq, Ord)
-- | Singleton for type-level selectors of kind @'By' 'Symbol' 'Nat'@.
data SBy (by :: By Symbol Nat) where
  -- | Select by a known type-level name.
  SByName :: forall name. KnownSymbol name => SBy ('ByName name)
  -- | Select by a known type-level index.
  SByIndex :: forall index. KnownNat index => SBy ('ByIndex index)

deriving stock instance Show (SBy (by :: By Symbol Nat))

type instance Sing = SBy

instance KnownSymbol name => SingI ('ByName name :: By Symbol Nat) where
  sing = SByName @name

instance KnownNat index => SingI ('ByIndex index :: By Symbol Nat) where
  sing = SByIndex @index

-- | Extract the 'Symbol' from a by-name selector.
type family ByNameF (by :: By Symbol Nat) :: Symbol where
  ByNameF ('ByName name) = name

-- | Extract the 'Nat' from a by-index selector.
type family ByIndexF (by :: By Symbol Nat) :: Nat where
  ByIndexF ('ByIndex index) = index
-- | Conversion between 'SBy' singletons and term-level
-- @'By' 'String' 'Integer'@.
--
-- Promoting a negative index raises an error, because no type-level 'Nat'
-- represents it.
instance SingKind (By Symbol Nat) where
  type Demote (By Symbol Nat) = By String Integer
  fromSing (SByName :: Sing by) = ByName . symbolVal $ Proxy @(ByNameF by)
  fromSing (SByIndex :: Sing by) = ByIndex . natVal $ Proxy @(ByIndexF by)
  toSing (ByName name) =
    case someSymbolVal name of
      SomeSymbol (_ :: Proxy s) -> SomeSing (SByName @s)
  toSing (ByIndex index) =
    case someNatVal index of
      Just (SomeNat (_ :: Proxy n)) -> SomeSing (SByIndex @n)
      -- 'someNatVal' returns 'Nothing' only for negative inputs; report that
      -- explicitly instead of dying on a non-exhaustive pattern.
      Nothing ->
        error $
          "Torch.GraduallyTyped.Shape.Type.toSing: cannot promote negative index "
            ++ show index
            ++ " to a type-level Nat"
-- | Reflect a type-level selector to a term-level
-- @'By' 'String' 'Integer'@.
class KnownBy (by :: By Symbol Nat) where
  byVal :: By String Integer

instance KnownSymbol name => KnownBy ('ByName name) where
  byVal = ByName (symbolVal (Proxy @name))

instance KnownNat index => KnownBy ('ByIndex index) where
  byVal = ByIndex (natVal (Proxy @index))
-- | A gradually typed selection of a single dimension.
data SelectDim (by :: Type) where
  -- | The selector is not tracked.
  UncheckedSelectDim :: forall by. SelectDim by
  -- | The selector is tracked and carries a value.
  SelectDim :: forall by. by -> SelectDim by

-- | Singleton for @'SelectDim' ('By' 'Symbol' 'Nat')@.
data SSelectDim (selectDim :: SelectDim (By Symbol Nat)) where
  -- | Unchecked selector; its value is only available at runtime.
  SUncheckedSelectDim :: By String Integer -> SSelectDim 'UncheckedSelectDim
  -- | Checked selector wrapping an 'SBy' singleton.
  SSelectDim :: forall by. SBy by -> SSelectDim ('SelectDim by)

deriving stock instance Show (SSelectDim (selectDim :: SelectDim (By Symbol Nat)))

type instance Sing = SSelectDim

instance SingI (by :: By Symbol Nat) => SingI ('SelectDim by) where
  sing = SSelectDim (sing @by)
-- | Conversion between 'SSelectDim' singletons and
-- @'IsChecked' ('By' 'String' 'Integer')@. The checked case delegates to the
-- 'SingKind' instance of @'By' 'Symbol' 'Nat'@.
instance SingKind (SelectDim (By Symbol Nat)) where
  type Demote (SelectDim (By Symbol Nat)) = IsChecked (By String Integer)
  fromSing (SUncheckedSelectDim by) = Unchecked by
  fromSing (SSelectDim sBy) = Checked (fromSing sBy)
  toSing (Unchecked by) = SomeSing (SUncheckedSelectDim by)
  toSing (Checked by) =
    withSomeSing by (\sBy -> SomeSing (SSelectDim sBy))
-- | Reflect a type-level single-dimension selector to the term level.
class KnownSelectDim (selectDim :: SelectDim (By Symbol Nat)) where
  selectDimVal :: SelectDim (By String Integer)

instance KnownSelectDim 'UncheckedSelectDim where
  selectDimVal = UncheckedSelectDim

instance KnownBy by => KnownSelectDim ('SelectDim by) where
  selectDimVal = SelectDim (byVal @by)
-- | A gradually typed selection of several dimensions.
data SelectDims (selectDims :: Type) where
  -- | The selectors are not tracked.
  UncheckedSelectDims ::
    forall selectDims.
    SelectDims selectDims
  -- | The selectors are tracked and carry a value.
  SelectDims ::
    forall selectDims.
    selectDims ->
    SelectDims selectDims

-- | Singleton for @'SelectDims' ['By' 'Symbol' 'Nat']@.
data SSelectDims (selectDims :: SelectDims [By Symbol Nat]) where
  -- | Unchecked selectors; only available at runtime.
  SUncheckedSelectDims :: [By String Integer] -> SSelectDims 'UncheckedSelectDims
  -- | Checked selectors wrapping a singleton list.
  SSelectDims :: forall bys. SList bys -> SSelectDims ('SelectDims bys)

deriving stock instance Show (SSelectDims (selectDims :: SelectDims [By Symbol Nat]))

type instance Sing = SSelectDims

instance SingI bys => SingI ('SelectDims (bys :: [By Symbol Nat])) where
  sing = SSelectDims (sing @bys)
-- | Conversion between 'SSelectDims' singletons and
-- @'IsChecked' ['By' 'String' 'Integer']@. The checked case delegates to
-- the list 'SingKind' instance.
instance SingKind (SelectDims [By Symbol Nat]) where
  type Demote (SelectDims [By Symbol Nat]) = IsChecked [By String Integer]
  fromSing (SUncheckedSelectDims bys) = Unchecked bys
  fromSing (SSelectDims sBys) = Checked (fromSing sBys)
  toSing (Unchecked bys) = SomeSing (SUncheckedSelectDims bys)
  toSing (Checked bys) =
    withSomeSing bys (\sBys -> SomeSing (SSelectDims sBys))
-- | Reflect a type-level list of dimension selectors to the term level,
-- preserving list order.
class KnownSelectDims (selectDims :: SelectDims [By Symbol Nat]) where
  selectDimsVal :: SelectDims [By String Integer]

instance KnownSelectDims 'UncheckedSelectDims where
  selectDimsVal = UncheckedSelectDims

instance KnownSelectDims ('SelectDims '[]) where
  selectDimsVal = SelectDims []

instance
  (KnownBy by, KnownSelectDims ('SelectDims bys)) =>
  KnownSelectDims ('SelectDims (by ': bys))
  where
  selectDimsVal = SelectDims (byVal @by : rest)
    where
      -- Lazy pattern binding: the tail instance always yields 'SelectDims'.
      SelectDims rest = selectDimsVal @('SelectDims bys)
-- | A gradually typed tensor shape.
--
-- At the type level it is used at kind
-- @'Shape' ['Dim' ('Name' 'Symbol') ('Size' 'Nat')]@ (see 'SShape').
data Shape (dims :: Type) where
  -- | The shape is not tracked.
  UncheckedShape ::
    forall dims.
    Shape dims
  -- | The shape is tracked and carries its dimensions.
  Shape ::
    forall dims.
    dims ->
    Shape dims
  deriving stock (Show)
-- | Singleton for type-level shapes.
data SShape (shape :: Shape [Dim (Name Symbol) (Size Nat)]) where
  -- | Unchecked shape; its dimensions are only available at runtime.
  SUncheckedShape :: [Dim String Integer] -> SShape 'UncheckedShape
  -- | Checked shape wrapping a singleton list of dimensions.
  SShape :: forall dims. SList dims -> SShape ('Shape dims)

deriving stock instance Show (SShape (shape :: Shape [Dim (Name Symbol) (Size Nat)]))

type instance Sing = SShape

instance SingI dims => SingI ('Shape (dims :: [Dim (Name Symbol) (Size Nat)])) where
  sing = SShape (sing @dims)
-- | Conversion between 'SShape' singletons and their demoted form.
--
-- For an unchecked shape, every runtime name and size is wrapped in
-- 'Unchecked' when demoting. When promoting, the per-dimension
-- 'IsChecked' wrappers are discarded with 'forgetIsChecked'.
instance SingKind (Shape [Dim (Name Symbol) (Size Nat)]) where
  type Demote (Shape [Dim (Name Symbol) (Size Nat)]) = IsChecked [Dim (IsChecked String) (IsChecked Integer)]
  fromSing (SUncheckedShape shape) =
    Unchecked (fmap (bimap Unchecked Unchecked) shape)
  fromSing (SShape sDims) = Checked (fromSing sDims)
  toSing (Unchecked shape) =
    SomeSing (SUncheckedShape (fmap (bimap forgetIsChecked forgetIsChecked) shape))
  toSing (Checked shape) =
    withSomeSing shape (\sDims -> SomeSing (SShape sDims))
-- | Reflect a type-level shape to a term-level shape, calling 'dimVal' on
-- each dimension in order.
class KnownShape (shape :: Shape [Dim (Name Symbol) (Size Nat)]) where
  shapeVal :: Shape [Dim (Name String) (Size Integer)]

instance KnownShape 'UncheckedShape where
  shapeVal = UncheckedShape

instance KnownShape ('Shape '[]) where
  shapeVal = Shape []

instance (KnownShape ('Shape dims), KnownDim dim) => KnownShape ('Shape (dim ': dims)) where
  -- The tail instance is for a checked shape, so it always yields 'Shape'.
  shapeVal = case shapeVal @('Shape dims) of
    Shape rest -> Shape (dimVal @dim : rest)
-- | Collect every type of kind @'Shape' ['Dim' ('Name' 'Symbol') ('Size' 'Nat')]@
-- occurring in @t@, left to right.
--
-- Equation order matters (closed type family):
--
-- * a shape itself yields a singleton list;
-- * a type application is searched on both sides;
-- * anything else yields no shapes.
--
-- NOTE(review): 'Concat' is assumed to be type-level list append.
type GetShapes :: k -> [Shape [Dim (Name Symbol) (Size Nat)]]
type family GetShapes t where
  GetShapes (shape :: Shape [Dim (Name Symbol) (Size Nat)]) = '[shape]
  GetShapes (fun arg) = Concat (GetShapes fun) (GetShapes arg)
  GetShapes _ = '[]
-- | Marshal between a Haskell 'String' and an ATen @Dimname@.
--
-- Both directions already run in 'IO', so the FFI calls are sequenced
-- directly rather than hidden behind 'unsafePerformIO'. That keeps
-- allocation order and exception timing predictable.
instance Castable String (ForeignPtr ATen.Dimname) where
  -- String -> StdString -> Symbol (dimname_s) -> Dimname (fromSymbol_s).
  cast name f = do
    str <- ATen.newStdString_s name
    symbol <- ATen.dimname_s str
    dimname <- ATen.fromSymbol_s symbol
    f dimname

  -- Dimname -> Symbol (dimname_symbol) -> unqualified StdString -> String.
  uncast ptr f = do
    symbol <- ATen.dimname_symbol ptr
    str <- ATen.symbol_toUnqualString symbol
    name <- ATen.string_c_str str
    f name
-- | Marshal between a Haskell list of ATen @Dimname@ pointers and an ATen
-- @DimnameList@, preserving order.
instance Castable [ForeignPtr ATen.Dimname] (ForeignPtr ATen.DimnameList) where
  -- Allocate an empty list and push each name onto the back, in order.
  cast names f = do
    list <- ATen.newDimnameList
    mapM_ (ATen.dimnameList_push_back_n list) names
    f list

  -- Read every element at indices 0 .. size - 1.
  uncast ptr f = do
    len <- ATen.dimnameList_size ptr
    -- The size is an unsigned C type, so @len - 1@ wraps to its maximum
    -- value when the list is empty. Handle that case explicitly.
    let indices = if len == 0 then [] else [0 .. len - 1]
    names <- mapM (ATen.dimnameList_at_s ptr) indices
    f names
-- | Marshal between a list of Haskell 'String's and an ATen @DimnameList@.
-- It goes through the element-wise 'String' / @Dimname@ instance and the
-- list-level @[Dimname]@ / @DimnameList@ instance.
instance Castable [String] (ForeignPtr ATen.DimnameList) where
  cast names k = traverse toDimname names >>= \dimnames -> cast dimnames k
    where
      toDimname :: String -> IO (ForeignPtr ATen.Dimname)
      toDimname name = cast name pure

  uncast list k = uncast list (\dimnames -> traverse fromDimname dimnames >>= k)
    where
      fromDimname :: ForeignPtr ATen.Dimname -> IO String
      fromDimname dimname = uncast dimname pure
-- | Marshal between a Haskell list of 'Integer's (e.g. sizes) and an ATen
-- @IntArray@, preserving order.
instance Castable [Integer] (ForeignPtr ATen.IntArray) where
  -- Allocate an empty array and push each value onto the back.
  -- NOTE(review): 'fromInteger' silently wraps values outside the range of
  -- the element type pushed by 'intArray_push_back_l' (Int64 per its
  -- signature); confirm callers never pass such values.
  cast sizes f = do
    array <- ATen.newIntArray
    mapM_ (ATen.intArray_push_back_l array . fromInteger) sizes
    f array

  -- Read every element at indices 0 .. size - 1 and widen to 'Integer'.
  uncast ptr f = do
    len <- ATen.intArray_size ptr
    -- The size is an unsigned C type, so @len - 1@ wraps to its maximum
    -- value for an empty array (e.g. a scalar's shape). Handle that case
    -- explicitly.
    let indices = if len == 0 then [] else [0 .. len - 1]
    sizes <- mapM (fmap toInteger . ATen.intArray_at_s ptr) indices
    f sizes