{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.Shape.Type where

import Data.Bifunctor (Bifunctor (..))
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import Foreign.ForeignPtr (ForeignPtr)
import GHC.TypeLits (KnownNat, KnownSymbol, Nat, SomeNat (..), SomeSymbol (..), Symbol, natVal, someNatVal, someSymbolVal, symbolVal)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..), forgetIsChecked)
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Managed.Cast as ATen ()
import qualified Torch.Internal.Managed.Type.Dimname as ATen (dimname_symbol, fromSymbol_s)
import qualified Torch.Internal.Managed.Type.DimnameList as ATen (dimnameList_at_s, dimnameList_push_back_n, dimnameList_size, newDimnameList)
import qualified Torch.Internal.Managed.Type.IntArray as ATen (intArray_at_s, intArray_push_back_l, intArray_size, newIntArray)
import qualified Torch.Internal.Managed.Type.StdString as ATen (newStdString_s, string_c_str)
import qualified Torch.Internal.Managed.Type.Symbol as ATen (dimname_s, symbol_toUnqualString)
import qualified Torch.Internal.Type as ATen (Dimname, DimnameList, IntArray)

-- | Term-level description of a dimension size.
--
-- Promoted to the kind level, @'Size' 'Nat'@ indexes shapes; at the term
-- level, @'Size' 'Integer'@ is what 'sizeVal' reflects to.
data Size (size :: Type) where
  -- | The size is unchecked: no size information is available.
  UncheckedSize :: forall size. Size size
  -- | The size is checked and carries its value.
  Size :: forall size. size -> Size size
  deriving (Show)

-- | Singleton for type-level sizes of kind @'Size' 'Nat'@.
data SSize (s :: Size Nat) where
  -- | Unchecked size; the value is carried as a runtime 'Integer'.
  SUncheckedSize :: Integer -> SSize 'UncheckedSize
  -- | Checked size; the value is recoverable through 'KnownNat'.
  SSize :: forall n. KnownNat n => SSize ('Size n)

deriving stock instance Show (SSize (s :: Size Nat))

type instance Sing = SSize

-- | A checked size whose 'Nat' is known has a canonical singleton.
instance KnownNat size => SingI ('Size size) where
  sing = SSize

-- | Extract the natural number of a checked type-level size.
-- Stuck (irreducible) on 'UncheckedSize'.
type family SizeF (size :: Size Nat) :: Nat where
  SizeF ('Size size) = size

-- | Convert between type-level sizes and their demoted form
-- @'IsChecked' 'Integer'@.
instance SingKind (Size Nat) where
  type Demote (Size Nat) = IsChecked Integer
  fromSing (SUncheckedSize size) = Unchecked size
  fromSing (SSize :: Sing size) = Checked . natVal $ Proxy @(SizeF size)
  toSing (Unchecked size) = SomeSing . SUncheckedSize $ size
  toSing (Checked size) = case someNatVal size of
    Just (SomeNat (_ :: Proxy n)) -> SomeSing (SSize @n)
    -- 'someNatVal' returns 'Nothing' for negative numbers; fail with a clear
    -- message instead of a non-exhaustive pattern-match error.
    Nothing -> error $ "toSing: a checked size must be non-negative, got " <> show size

-- | Reflect a type-level size to its term-level @'Size' 'Integer'@.
class KnownSize (size :: Size Nat) where
  sizeVal :: Size Integer

instance KnownSize 'UncheckedSize where
  sizeVal = UncheckedSize

instance KnownNat size => KnownSize ('Size size) where
  sizeVal = Size (natVal $ Proxy @size)

-- | Term-level description of a dimension name.
--
-- Promoted to the kind level, @'Name' 'Symbol'@ indexes shapes; at the term
-- level, @'Name' 'String'@ is what 'nameVal' reflects to.
data Name (name :: Type) where
  -- | The name is unchecked: no name information is available.
  UncheckedName :: forall name. Name name
  -- | The name is checked and carries its value.
  Name :: forall name. name -> Name name
  deriving (Show)

-- | Singleton for type-level names of kind @'Name' 'Symbol'@.
data SName (name :: Name Symbol) where
  -- | Unchecked name; the value is carried as a runtime 'String'.
  SUncheckedName :: String -> SName 'UncheckedName
  -- | Checked name; the value is recoverable through 'KnownSymbol'.
  SName :: forall name. KnownSymbol name => SName ('Name name)

deriving stock instance Show (SName (name :: Name Symbol))

-- | Singleton for the checked name @\"*\"@.
-- NOTE(review): presumably the placeholder for an unnamed dimension — confirm.
pattern SNoName :: SName ('Name "*")
pattern SNoName = SName

type instance Sing = SName

-- | A checked name whose 'Symbol' is known has a canonical singleton.
instance KnownSymbol name => SingI ('Name name) where
  sing = SName

-- | Extract the symbol of a checked type-level name.
-- Stuck (irreducible) on 'UncheckedName'.
type family NameF (name :: Name Symbol) :: Symbol where
  NameF ('Name name) = name

-- | Convert between type-level names and their demoted form
-- @'IsChecked' 'String'@.
instance SingKind (Name Symbol) where
  type Demote (Name Symbol) = IsChecked String
  fromSing (SUncheckedName name) = Unchecked name
  fromSing (SName :: Sing name) = Checked . symbolVal $ Proxy @(NameF name)
  toSing (Unchecked name) = SomeSing . SUncheckedName $ name
  -- 'someSymbolVal' is total, so every string yields a checked singleton.
  toSing (Checked name) = case someSymbolVal name of
    SomeSymbol (_ :: Proxy s) -> SomeSing (SName @s)

-- | Reflect a type-level name to its term-level @'Name' 'String'@.
class KnownName (name :: Name Symbol) where
  nameVal :: Name String

instance KnownName 'UncheckedName where
  nameVal = UncheckedName

instance KnownSymbol name => KnownName ('Name name) where
  nameVal = Name (symbolVal $ Proxy @name)

-- | A tensor dimension: a name paired with a size.
data Dim (name :: Type) (size :: Type) where
  Dim ::
    forall name size.
    { dimName :: name,
      dimSize :: size
    } ->
    Dim name size
  deriving (Eq, Ord, Show)

-- | Map over the name (first argument) and the size (second argument) of a
-- dimension independently.
instance Bifunctor Dim where
  bimap f g (Dim name size) = Dim (f name) (g size)

-- | Singleton for type-level dimensions: a name singleton paired with a size
-- singleton.
data SDim (dim :: Dim (Name Symbol) (Size Nat)) where
  SDim ::
    forall name size.
    { sDimName :: SName name,
      sDimSize :: SSize size
    } ->
    SDim ('Dim name size)

deriving stock instance Show (SDim (dim :: Dim (Name Symbol) (Size Nat)))

type instance Sing = SDim

-- | A dimension has a canonical singleton when both its name and its size
-- are checked and known.
instance (KnownSymbol name, KnownNat size) => SingI ('Dim ('Name name) ('Size size)) where
  sing = SDim (sing @('Name name)) (sing @('Size size))

-- | Convert between type-level dimensions and their demoted form, in which
-- both the name and the size are wrapped in 'IsChecked'.
instance SingKind (Dim (Name Symbol) (Size Nat)) where
  type Demote (Dim (Name Symbol) (Size Nat)) = Dim (IsChecked String) (IsChecked Integer)
  fromSing (SDim name size) = Dim (fromSing name) (fromSing size)
  -- Promote the name and the size separately, then pair the singletons.
  -- The kinds are given explicitly because 'Demote' alone does not fix them.
  toSing (Dim name size) =
    withSomeSing @(Name Symbol) name $ \name' ->
      withSomeSing @(Size Nat) size $ \size' ->
        SomeSing $ SDim name' size'

-- | Infix synonym for 'SDim': @name :&: size@ builds or matches a dimension
-- singleton.
pattern (:&:) ::
  forall
    (name :: Name Symbol)
    (size :: Size Nat).
  SName name ->
  SSize size ->
  SDim ('Dim name size)
pattern name :&: size = SDim name size

infix 9 :&:

-- | Reflect a type-level dimension to its term-level representation.
class KnownDim (dim :: Dim (Name Symbol) (Size Nat)) where
  dimVal :: Dim (Name String) (Size Integer)

instance (KnownName name, KnownSize size) => KnownDim ('Dim name size) where
  dimVal = Dim (nameVal @name) (sizeVal @size)

-- | Data type to select dimensions by name or by index.
data By (name :: Type) (index :: Type) where
  -- | Select a dimension by name.
  ByName ::
    forall name index.
    name ->
    By name index
  -- | Select a dimension by index. Counting starts at zero for the first dimension.
  ByIndex ::
    forall name index.
    index ->
    By name index
  deriving (Show, Eq, Ord)

-- | Singleton for type-level dimension selectors of kind @'By' 'Symbol' 'Nat'@.
data SBy (b :: By Symbol Nat) where
  -- | Selection by a name known through 'KnownSymbol'.
  SByName :: forall s. KnownSymbol s => SBy ('ByName s)
  -- | Selection by an index known through 'KnownNat'.
  SByIndex :: forall i. KnownNat i => SBy ('ByIndex i)

deriving stock instance Show (SBy (b :: By Symbol Nat))

type instance Sing = SBy

-- | Selection by a known name has a canonical singleton.
instance KnownSymbol name => SingI ('ByName name :: By Symbol Nat) where
  sing = SByName @name

-- | Selection by a known index has a canonical singleton.
instance KnownNat index => SingI ('ByIndex index :: By Symbol Nat) where
  sing = SByIndex @index

-- | Extract the name of a by-name selector. Stuck on 'ByIndex'.
type family ByNameF (by :: By Symbol Nat) :: Symbol where
  ByNameF ('ByName name) = name

-- | Extract the index of a by-index selector. Stuck on 'ByName'.
type family ByIndexF (by :: By Symbol Nat) :: Nat where
  ByIndexF ('ByIndex index) = index

-- | Convert between type-level selectors and their demoted form
-- @'By' 'String' 'Integer'@.
instance SingKind (By Symbol Nat) where
  type Demote (By Symbol Nat) = By String Integer
  fromSing (SByName :: Sing by) = ByName . symbolVal $ Proxy @(ByNameF by)
  fromSing (SByIndex :: Sing by) = ByIndex . natVal $ Proxy @(ByIndexF by)
  toSing (ByName name) = case someSymbolVal name of
    SomeSymbol (_ :: Proxy s) -> SomeSing (SByName @s)
  toSing (ByIndex index) = case someNatVal index of
    Just (SomeNat (_ :: Proxy i)) -> SomeSing (SByIndex @i)
    -- 'someNatVal' returns 'Nothing' for negative numbers; fail with a clear
    -- message instead of a non-exhaustive pattern-match error.
    Nothing -> error $ "toSing: a dimension index must be non-negative, got " <> show index

-- | Reflect a type-level selector to its term-level @'By' 'String' 'Integer'@.
class KnownBy (by :: By Symbol Nat) where
  byVal :: By String Integer

instance KnownSymbol name => KnownBy ('ByName name) where
  byVal = ByName (symbolVal $ Proxy @name)

instance KnownNat index => KnownBy ('ByIndex index) where
  byVal = ByIndex (natVal $ Proxy @index)

-- | Term-level description of how a single dimension is selected.
data SelectDim (b :: Type) where
  -- | The selection method is unknown.
  UncheckedSelectDim :: forall b. SelectDim b
  -- | The selection method is known: either by name or by index.
  SelectDim :: forall b. b -> SelectDim b

-- | Singleton for type-level dimension selections.
data SSelectDim (sel :: SelectDim (By Symbol Nat)) where
  -- | Unchecked selection; the selector is carried at runtime.
  SUncheckedSelectDim :: By String Integer -> SSelectDim 'UncheckedSelectDim
  -- | Checked selection, wrapping a selector singleton.
  SSelectDim :: forall b. SBy b -> SSelectDim ('SelectDim b)

deriving stock instance Show (SSelectDim (sel :: SelectDim (By Symbol Nat)))

type instance Sing = SSelectDim

-- | A checked selection has a singleton whenever its selector does.
instance SingI (by :: By Symbol Nat) => SingI ('SelectDim by) where
  sing = SSelectDim $ sing @by

-- | Convert between type-level selections and their demoted form
-- @'IsChecked' ('By' 'String' 'Integer')@.
instance SingKind (SelectDim (By Symbol Nat)) where
  type Demote (SelectDim (By Symbol Nat)) = IsChecked (By String Integer)
  fromSing (SUncheckedSelectDim by) = Unchecked by
  fromSing (SSelectDim by) = Checked . fromSing $ by
  toSing (Unchecked by) = SomeSing . SUncheckedSelectDim $ by
  -- The selector kind is given explicitly because 'Demote' alone does not fix it.
  toSing (Checked by) = withSomeSing @(By Symbol Nat) by $ SomeSing . SSelectDim

-- | Reflect a type-level selection to its term-level representation.
class KnownSelectDim (selectDim :: SelectDim (By Symbol Nat)) where
  selectDimVal :: SelectDim (By String Integer)

instance KnownSelectDim 'UncheckedSelectDim where
  selectDimVal = UncheckedSelectDim

instance KnownBy by => KnownSelectDim ('SelectDim by) where
  selectDimVal = SelectDim (byVal @by)

-- | Term-level description of how a collection of dimensions is selected.
data SelectDims (sels :: Type) where
  -- | The selection is unknown.
  UncheckedSelectDims ::
    forall sels.
    SelectDims sels
  -- | The selection is known.
  SelectDims ::
    forall sels.
    sels ->
    SelectDims sels

-- | Singleton for type-level multi-dimension selections.
data SSelectDims (sels :: SelectDims [By Symbol Nat]) where
  -- | Unchecked selection; the selectors are carried at runtime.
  SUncheckedSelectDims :: [By String Integer] -> SSelectDims 'UncheckedSelectDims
  -- | Checked selection, wrapping a list singleton of selectors.
  SSelectDims :: forall bs. SList bs -> SSelectDims ('SelectDims bs)

deriving stock instance Show (SSelectDims (sels :: SelectDims [By Symbol Nat]))

type instance Sing = SSelectDims

-- | A checked multi-selection has a singleton whenever its selector list does.
instance SingI bys => SingI ('SelectDims (bys :: [By Symbol Nat])) where
  sing = SSelectDims (sing @bys)

-- | Convert between type-level multi-selections and their demoted form
-- @'IsChecked' ['By' 'String' 'Integer']@.
instance SingKind (SelectDims [By Symbol Nat]) where
  type Demote (SelectDims [By Symbol Nat]) = IsChecked [By String Integer]
  fromSing (SUncheckedSelectDims bys) = Unchecked bys
  fromSing (SSelectDims bys) = Checked . fromSing $ bys
  toSing (Unchecked bys) = SomeSing . SUncheckedSelectDims $ bys
  -- The list kind is given explicitly because 'Demote' alone does not fix it.
  toSing (Checked bys) = withSomeSing @[By Symbol Nat] bys $ SomeSing . SSelectDims

-- | Reflect a type-level multi-selection to its term-level representation.
class KnownSelectDims (selectDims :: SelectDims [By Symbol Nat]) where
  selectDimsVal :: SelectDims [By String Integer]

instance KnownSelectDims 'UncheckedSelectDims where
  selectDimsVal = UncheckedSelectDims

instance KnownSelectDims ('SelectDims '[]) where
  selectDimsVal = SelectDims []

-- | Reflect the head selector and prepend it to the reflected tail.
instance
  (KnownBy by, KnownSelectDims ('SelectDims bys)) =>
  KnownSelectDims ('SelectDims (by ': bys))
  where
  selectDimsVal = case selectDimsVal @('SelectDims bys) of
    SelectDims bys -> SelectDims (byVal @by : bys)
    -- Both instances for 'SelectDims return 'SelectDims', so this branch is
    -- unreachable. It is spelled out instead of using a partial pattern.
    UncheckedSelectDims ->
      error "selectDimsVal: impossible, the instances for 'SelectDims always yield SelectDims"

-- | Data type to represent tensor shapes, that is, lists of dimensions.
data Shape (dims :: Type) where
  -- | The shape is fully unchecked.
  -- Neither the number of the dimensions
  -- nor any dimension properties are known to the compiler.
  UncheckedShape ::
    forall dims.
    Shape dims
  -- | The shape is partially known to the compiler.
  -- The list of dimensions has a known length, but individual dimensions may
  -- still have an unchecked name ('UncheckedName') or size ('UncheckedSize').
  Shape ::
    forall dims.
    dims ->
    Shape dims
  deriving (Show)

-- | Singleton for type-level shapes.
data SShape (s :: Shape [Dim (Name Symbol) (Size Nat)]) where
  -- | Unchecked shape; its dimensions are carried at runtime.
  SUncheckedShape :: [Dim String Integer] -> SShape 'UncheckedShape
  -- | Checked shape, wrapping a list singleton of dimensions.
  SShape :: forall ds. SList ds -> SShape ('Shape ds)

deriving stock instance Show (SShape (s :: Shape [Dim (Name Symbol) (Size Nat)]))

type instance Sing = SShape

-- | A checked shape has a singleton whenever its list of dimensions does.
instance SingI dims => SingI ('Shape (dims :: [Dim (Name Symbol) (Size Nat)])) where
  sing = SShape $ sing @dims

-- | Convert between type-level shapes and their demoted form, in which the
-- shape and each dimension's name and size are wrapped in 'IsChecked'.
instance SingKind (Shape [Dim (Name Symbol) (Size Nat)]) where
  type Demote (Shape [Dim (Name Symbol) (Size Nat)]) = IsChecked [Dim (IsChecked String) (IsChecked Integer)]
  -- Every component of an unchecked shape is reported as 'Unchecked'.
  fromSing (SUncheckedShape shape) = Unchecked . fmap (bimap Unchecked Unchecked) $ shape
  fromSing (SShape dims) = Checked . fromSing $ dims
  -- 'SUncheckedShape' stores plain values, so any per-dimension 'Checked'
  -- markers are dropped with 'forgetIsChecked'.
  toSing (Unchecked shape) =
    SomeSing . SUncheckedShape . fmap (bimap forgetIsChecked forgetIsChecked) $ shape
  -- The list kind is given explicitly because 'Demote' alone does not fix it.
  toSing (Checked shape) =
    withSomeSing @[Dim (Name Symbol) (Size Nat)] shape $ SomeSing . SShape

-- | Reflect a type-level shape to its term-level representation.
class KnownShape (shape :: Shape [Dim (Name Symbol) (Size Nat)]) where
  shapeVal :: Shape [Dim (Name String) (Size Integer)]

instance KnownShape 'UncheckedShape where
  shapeVal = UncheckedShape

instance KnownShape ('Shape '[]) where
  shapeVal = Shape []

-- | Reflect the head dimension and prepend it to the reflected tail.
instance (KnownShape ('Shape dims), KnownDim dim) => KnownShape ('Shape (dim ': dims)) where
  shapeVal =
    case shapeVal @('Shape dims) of
      Shape dims -> Shape $ dimVal @dim : dims
      -- Both instances for 'Shape return 'Shape', so this branch is
      -- unreachable. It is spelled out so the match is exhaustive.
      UncheckedShape ->
        error "shapeVal: impossible, the instances for 'Shape always yield Shape"

-- >>> :kind! GetShapes ('Shape '[ 'Dim ('Name "*") ('Size 1)])
-- GetShapes ('Shape '[ 'Dim ('Name "*") ('Size 1)]) :: [Shape [Dim (Name Symbol) (Size Nat)]]
-- = '[ 'Shape '[ 'Dim ('Name "*") ('Size 1)]]
-- >>> :kind! GetShapes '[ 'Shape '[ 'Dim ('Name "*") ('Size 1)], 'Shape '[ 'Dim 'UncheckedName ('Size 0)]]
-- GetShapes '[ 'Shape '[ 'Dim ('Name "*") ('Size 1)], 'Shape '[ 'Dim 'UncheckedName ('Size 0)]] :: [Shape [Dim (Name Symbol) (Size Nat)]]
-- = '[ 'Shape '[ 'Dim ('Name "*") ('Size 1)],
--      'Shape '[ 'Dim 'UncheckedName ('Size 0)]]
-- >>> :kind! GetShapes ('Just ('Shape '[ 'Dim ('Name "*") ('Size 1)]))
-- GetShapes ('Just ('Shape '[ 'Dim ('Name "*") ('Size 1)])) :: [Shape [Dim (Name Symbol) (Size Nat)]]
-- = '[ 'Shape '[ 'Dim ('Name "*") ('Size 1)]]

-- | Collect every shape that occurs anywhere inside a type of any kind.
-- A type of shape kind yields itself. A type application contributes the
-- shapes of its function part followed by those of its argument. Anything
-- else contributes nothing.
type GetShapes :: k -> [Shape [Dim (Name Symbol) (Size Nat)]]
type family GetShapes x where
  GetShapes (shape :: Shape [Dim (Name Symbol) (Size Nat)]) = '[shape]
  GetShapes (h t) = Concat (GetShapes h) (GetShapes t)
  GetShapes _ = '[]

-- | Marshal a dimension name between a Haskell 'String' and a libtorch
-- 'ATen.Dimname'.
--
-- Both methods already run in 'IO', so the foreign calls are sequenced
-- directly before the continuation is invoked. Routing them through
-- 'unsafePerformIO' would instead defer the calls, and any exception they
-- raise, to whenever the continuation first forced the value.
instance Castable String (ForeignPtr ATen.Dimname) where
  -- String -> std::string -> Symbol -> Dimname, then hand the result to @f@.
  cast name f =
    ATen.newStdString_s name
      >>= ATen.dimname_s
      >>= ATen.fromSymbol_s
      >>= f

  -- Dimname -> Symbol -> unqualified std::string -> String, then hand the
  -- result to @f@.
  uncast ptr f =
    ATen.dimname_symbol ptr
      >>= ATen.symbol_toUnqualString
      >>= ATen.string_c_str
      >>= f

-- | Marshal a list of 'ATen.Dimname' pointers to and from a libtorch
-- 'ATen.DimnameList'.
--
-- The foreign calls run directly in 'IO' before the continuation is invoked,
-- rather than lazily through 'unsafePerformIO'.
instance Castable [ForeignPtr ATen.Dimname] (ForeignPtr ATen.DimnameList) where
  -- Build a fresh list and append the names in their original order.
  cast names f = do
    list <- ATen.newDimnameList
    mapM_ (ATen.dimnameList_push_back_n list) names
    f list

  -- Read every element by index, from 0 up to the reported size.
  uncast ptr f = do
    len <- ATen.dimnameList_size ptr
    -- The size is an unsigned 'CSize': for an empty list, @len - 1@ would
    -- wrap around to maxBound and index far out of bounds. Enumerating the
    -- indices as 'Integer' makes the empty case produce no indices.
    names <-
      mapM
        (ATen.dimnameList_at_s ptr . fromInteger)
        [0 .. toInteger len - 1]
    f names

-- | Marshal a list of dimension names to and from a libtorch
-- 'ATen.DimnameList'.
--
-- This instance composes two steps: each name is converted individually
-- through the @String@ / @ForeignPtr ATen.Dimname@ instance, and the
-- resulting pointer list is converted through the
-- @[ForeignPtr ATen.Dimname]@ / @ForeignPtr ATen.DimnameList@ instance.
instance Castable [String] (ForeignPtr ATen.DimnameList) where
  cast xs f =
    traverse toDimname xs >>= \ptrList -> cast ptrList f
    where
      -- Convert one name, returning the pointer instead of consuming it.
      toDimname :: String -> IO (ForeignPtr ATen.Dimname)
      toDimname x = cast x pure

  uncast xs f =
    uncast xs $ \(ptrList :: [ForeignPtr ATen.Dimname]) ->
      traverse (\ptr -> uncast ptr pure) ptrList >>= f

-- | Marshal a list of sizes to and from a libtorch 'ATen.IntArray'.
--
-- The foreign calls run directly in 'IO' before the continuation is invoked,
-- rather than lazily through 'unsafePerformIO'. As a result, allocation and
-- any foreign exception happen at a predictable point.
instance Castable [Integer] (ForeignPtr ATen.IntArray) where
  -- Build a fresh array and append the sizes in their original order.
  cast sizes f = do
    array <- ATen.newIntArray
    -- NOTE(review): 'fromInteger' into the 64-bit element type wraps
    -- silently for values outside its range.
    mapM_ (ATen.intArray_push_back_l array . fromInteger) sizes
    f array

  -- Read every element by index and widen it back to 'Integer'.
  uncast ptr f = do
    len <- ATen.intArray_size ptr
    -- The size is an unsigned 'CSize': for an empty array, @len - 1@ would
    -- wrap around to maxBound and index far out of bounds. Enumerating the
    -- indices as 'Integer' makes the empty case produce no indices.
    sizes <-
      mapM
        (fmap toInteger . ATen.intArray_at_s ptr . fromInteger)
        [0 .. toInteger len - 1]
    f sizes