{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
module Torch.GraduallyTyped.RequiresGradient where
import Data.Kind (Type)
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import Data.Singletons.TH (genSingletons)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..))
-- | Term-level flag recording whether gradients are required.
--
-- It is also promoted to the type level (see the 'genSingletons' splice that
-- follows) so the requirement can be tracked statically.
data RequiresGradient
  = -- | A gradient is required.
    WithGradient
  | -- | No gradient is required.
    WithoutGradient
  deriving (Show, Eq)
-- Template Haskell: derive the singleton type 'SRequiresGradient' for
-- 'RequiresGradient', together with the singletons library's standard
-- instances for it ('SingI' per constructor, 'SingKind' for the kind).
genSingletons [''RequiresGradient]

-- The generated singleton type gets no 'Show' instance from the splice, so one
-- is derived standalone here. 'SGradient' further below relies on it for its
-- own derived 'Show'.
deriving stock instance Show (SRequiresGradient (rg :: RequiresGradient))
-- | Demotes a type-level 'RequiresGradient' to the matching term-level value.
--
-- The class variable does not appear in the method's type, so the instance is
-- chosen with a type application at the call site, for example
-- @requiresGradientVal \@'WithGradient@. This is why the module enables
-- @AllowAmbiguousTypes@.
class KnownRequiresGradient (gradientMode :: RequiresGradient) where
  -- | The term-level counterpart of @gradientMode@.
  requiresGradientVal :: RequiresGradient
-- | The promoted @'WithGradient@ demotes to 'WithGradient'.
--
-- No instance signature is given: the class already fixes the method type,
-- and @InstanceSigs@ is not among this module's extensions.
instance KnownRequiresGradient 'WithGradient where
  requiresGradientVal = WithGradient
-- | The promoted @'WithoutGradient@ demotes to 'WithoutGradient'.
--
-- No instance signature is given: the class already fixes the method type,
-- and @InstanceSigs@ is not among this module's extensions.
instance KnownRequiresGradient 'WithoutGradient where
  requiresGradientVal = WithoutGradient
-- | A gradient requirement that is either statically checked or left unchecked.
--
-- The parameter is the type of the requirement. Elsewhere in this module it is
-- instantiated at 'RequiresGradient', and the constructors are used promoted
-- (@'UncheckedGradient@, @'Gradient r@).
data Gradient (requiresGradient :: Type) where
  -- | The requirement is not tracked at this level.
  UncheckedGradient :: forall requiresGradient. Gradient requiresGradient
  -- | The requirement is tracked and given by the field.
  Gradient :: forall requiresGradient. requiresGradient -> Gradient requiresGradient
  deriving stock (Show)
-- | Singleton type for the kind @Gradient RequiresGradient@.
--
-- The unchecked case has no type-level requirement to index by, so
-- 'SUncheckedGradient' stores it as an ordinary runtime 'RequiresGradient'.
-- The checked case wraps the singleton of the type-level requirement.
data SGradient (g :: Gradient RequiresGradient) where
  SUncheckedGradient ::
    RequiresGradient ->
    SGradient 'UncheckedGradient
  SGradient ::
    forall requiresGradient.
    SRequiresGradient requiresGradient ->
    SGradient ('Gradient requiresGradient)

deriving stock instance Show (SGradient (g :: Gradient RequiresGradient))

-- Register 'SGradient' as the singleton type for this kind.
type instance Sing = SGradient
-- | A statically known requirement yields a statically known checked gradient.
--
-- No 'SingI' instance exists for @'UncheckedGradient@. Its singleton
-- ('SUncheckedGradient') needs a runtime 'RequiresGradient' value, which
-- cannot be produced from the type alone.
instance SingI requiresGradient => SingI ('Gradient (requiresGradient :: RequiresGradient)) where
  -- The instance-head variable is in scope here (ScopedTypeVariables), so
  -- we can request its singleton explicitly with a type application.
  sing = SGradient (sing @requiresGradient)
-- | Converts between 'SGradient' singletons and plain @IsChecked RequiresGradient@
-- values.
--
-- The correspondence is:
--
-- * 'SUncheckedGradient' @r@ and @Unchecked r@ (the runtime value is carried as is)
-- * 'SGradient' @s@ and @Checked (fromSing s)@
instance SingKind (Gradient RequiresGradient) where
  type Demote (Gradient RequiresGradient) = IsChecked RequiresGradient

  fromSing (SUncheckedGradient requiresGradient) = Unchecked requiresGradient
  fromSing (SGradient sRequiresGradient) = Checked (fromSing sRequiresGradient)

  toSing (Unchecked requiresGradient) =
    SomeSing (SUncheckedGradient requiresGradient)
  -- Lift the runtime value to an existentially quantified singleton, then
  -- wrap it as a checked gradient.
  toSing (Checked requiresGradient) =
    withSomeSing requiresGradient (SomeSing . SGradient)
-- | Demotes a type-level @Gradient RequiresGradient@ to its term-level value.
--
-- As with 'KnownRequiresGradient', the class variable does not occur in the
-- method type. Callers select the instance with a type application, e.g.
-- @gradientVal \@('Gradient 'WithGradient)@.
class KnownGradient (g :: Gradient RequiresGradient) where
  -- | The term-level counterpart of @g@.
  gradientVal :: Gradient RequiresGradient
-- | The promoted @'UncheckedGradient@ demotes to 'UncheckedGradient'.
--
-- No instance signature is given: the class already fixes the method type,
-- and @InstanceSigs@ is not among this module's extensions.
instance KnownGradient 'UncheckedGradient where
  gradientVal = UncheckedGradient
-- | A checked gradient demotes to 'Gradient' applied to the demoted requirement.
instance KnownRequiresGradient requiresGradient => KnownGradient ('Gradient requiresGradient) where
  -- 'requiresGradientVal' has an ambiguous type, so the instance-head
  -- variable must be passed explicitly.
  gradientVal = Gradient (requiresGradientVal @requiresGradient)
-- | Collects every type of kind @Gradient RequiresGradient@ that occurs inside
-- a type.
--
-- This is a closed type family, so the equations are tried in order:
--
-- 1. A type that itself has kind @Gradient RequiresGradient@ becomes a
--    one-element list. This equation is checked first, so an application such
--    as @'Gradient 'WithGradient@ is not split by the next equation.
-- 2. Any other type application is split, and both halves are searched and
--    combined with 'Concat' (presumably list append; confirm in the Prelude).
-- 3. Anything else contributes an empty list.
type GetGradients :: k -> [Gradient RequiresGradient]
type family GetGradients t where
  GetGradients (gradient :: Gradient RequiresGradient) = '[gradient]
  GetGradients (fun arg) = Concat (GetGradients fun) (GetGradients arg)
  GetGradients _ = '[]