{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.GraduallyTyped.RequiresGradient where

import Data.Kind (Type)
import Data.Singletons (Sing, SingI (..), SingKind (..), SomeSing (..), withSomeSing)
import Data.Singletons.TH (genSingletons)
import Torch.GraduallyTyped.Prelude (Concat, IsChecked (..))

-- | Data type to represent whether or not the tensor requires gradient computations.
--
-- The constructors are also used promoted (via @DataKinds@) at the type level,
-- e.g. @'WithGradient@.
data RequiresGradient
  = -- | The tensor requires gradient computations.
    WithGradient
  | -- | Gradient computations for this tensor are disabled.
    WithoutGradient
  deriving stock (Show, Eq)

-- Generate the singleton machinery for 'RequiresGradient' with singletons-th;
-- this provides the 'SRequiresGradient' singleton type referenced below.
$(genSingletons [''RequiresGradient])

-- Standalone-derived 'Show' for the generated singleton type.
deriving stock instance Show (SRequiresGradient (rg :: RequiresGradient))

-- | Reflects a type-level 'RequiresGradient' to its term-level value.
--
-- The method's type does not mention the class parameter (this module enables
-- @AllowAmbiguousTypes@), so the instance is selected with a type application,
-- e.g. @requiresGradientVal \@'WithGradient@.
class KnownRequiresGradient (requiresGradient :: RequiresGradient) where
  requiresGradientVal :: RequiresGradient

instance KnownRequiresGradient 'WithGradient where
  requiresGradientVal = WithGradient

instance KnownRequiresGradient 'WithoutGradient where
  requiresGradientVal = WithoutGradient

-- | Data type to represent whether or not it is known by the compiler if the tensor requires gradient computations.
--
-- It is used promoted at kind @Gradient RequiresGradient@: 'UncheckedGradient'
-- means the compiler does not know the setting, while @'Gradient r@ records the
-- setting @r@ at the type level.
data Gradient (requiresGradient :: Type) where
  -- | Whether or not the tensor requires gradient computations is unknown to the compiler.
  UncheckedGradient :: forall requiresGradient. Gradient requiresGradient
  -- | Whether or not the tensor requires gradient computations is known to the compiler.
  Gradient :: forall requiresGradient. requiresGradient -> Gradient requiresGradient
  deriving stock (Show)

-- | Singleton type for kind @Gradient RequiresGradient@.
--
-- The unchecked case carries the gradient setting as an ordinary runtime
-- 'RequiresGradient' value; the checked case carries the singleton of the
-- type-level setting.
data SGradient (gradient :: Gradient RequiresGradient) where
  SUncheckedGradient ::
    RequiresGradient ->
    SGradient 'UncheckedGradient
  SGradient ::
    forall requiresGradient.
    SRequiresGradient requiresGradient ->
    SGradient ('Gradient requiresGradient)

deriving stock instance Show (SGradient (gradient :: Gradient RequiresGradient))

-- Register 'SGradient' as the 'Sing' instance for kind @Gradient RequiresGradient@.
type instance Sing = SGradient

-- | A checked gradient setting has a singleton whenever its underlying
-- 'RequiresGradient' does; it is that singleton wrapped in 'SGradient'.
--
-- No instance for 'UncheckedGradient is given here: 'SUncheckedGradient'
-- needs a runtime 'RequiresGradient' value that the type does not determine.
instance SingI requiresGradient => SingI ('Gradient (requiresGradient :: RequiresGradient)) where
  sing = SGradient (sing @requiresGradient)

-- | Conversion between singletons of kind @Gradient RequiresGradient@ and
-- their demoted form @IsChecked RequiresGradient@.
--
-- 'SUncheckedGradient' corresponds to 'Unchecked' and 'SGradient' to 'Checked';
-- 'toSing' is the inverse of 'fromSing'.
instance SingKind (Gradient RequiresGradient) where
  type Demote (Gradient RequiresGradient) = IsChecked RequiresGradient

  -- The unchecked singleton already holds a plain value; pass it through.
  fromSing (SUncheckedGradient requiresGradient) = Unchecked requiresGradient
  -- Demote the inner 'SRequiresGradient' singleton, then tag it as checked.
  fromSing (SGradient requiresGradient) = Checked (fromSing requiresGradient)

  toSing (Unchecked requiresGradient) = SomeSing (SUncheckedGradient requiresGradient)
  -- Promote the value to an existentially quantified singleton and wrap it.
  toSing (Checked requiresGradient) = withSomeSing requiresGradient $ SomeSing . SGradient

-- | Reflects a type-level 'Gradient' setting to its term-level value.
--
-- As with 'requiresGradientVal', the method's type does not mention the class
-- parameter, so callers select the instance with a type application,
-- e.g. @gradientVal \@('Gradient 'WithGradient)@.
class KnownGradient (gradient :: Gradient RequiresGradient) where
  gradientVal :: Gradient RequiresGradient

instance KnownGradient 'UncheckedGradient where
  gradientVal = UncheckedGradient

-- | A checked setting reflects its underlying 'RequiresGradient'.
instance (KnownRequiresGradient requiresGradient) => KnownGradient ('Gradient requiresGradient) where
  gradientVal = Gradient (requiresGradientVal @requiresGradient)

-- Collects, in left-to-right order, every type of kind 'Gradient RequiresGradient'
-- that occurs anywhere inside an arbitrary (possibly nested) type, for example
-- inside a promoted list or a promoted 'Maybe'. Results of sub-terms are combined
-- with 'Concat' (presumably type-level list append, as the examples below show).
--
-- Examples:
--
-- >>> :kind! GetGradients ('Gradient 'WithGradient)
-- GetGradients ('Gradient 'WithGradient) :: [Gradient RequiresGradient]
-- = '[ 'Gradient 'WithGradient]
-- >>> :kind! GetGradients '[ 'Gradient 'WithoutGradient, 'Gradient 'WithGradient]
-- GetGradients '[ 'Gradient 'WithoutGradient, 'Gradient 'WithGradient] :: [Gradient RequiresGradient]
-- = '[ 'Gradient 'WithoutGradient, 'Gradient 'WithGradient]
-- >>> :kind! GetGradients ('Just ('Gradient 'WithGradient))
-- GetGradients ('Just ('Gradient 'WithGradient)) :: [Gradient RequiresGradient]
-- = '[ 'Gradient 'WithGradient]
type GetGradients :: k -> [Gradient RequiresGradient]
type family GetGradients f where
  -- A gradient is collected as-is. This equation must come first: a type such as
  -- 'Gradient 'WithGradient is itself a type application and would otherwise be
  -- decomposed by the next equation.
  GetGradients (a :: Gradient RequiresGradient) = '[a]
  -- Any other type application: search both the head and the argument.
  GetGradients (f g) = Concat (GetGradients f) (GetGradients g)
  -- Anything else (e.g. a nullary constructor such as '[] or 'Nothing) contributes nothing.
  GetGradients _ = '[]