{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}

module Torch.GraduallyTyped.Autograd where

import Data.Kind (Type)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..))
import Torch.GraduallyTyped.Tensor (Tensor)
import Torch.Internal.Cast (cast2)
import qualified Torch.Internal.Managed.Autograd as ATen

-- | Parameter types for which the gradients of a loss can be calculated.
class HasGrad parameters where
  -- | Type of the loss that 'grad' accepts for these @parameters@.
  type Loss parameters :: Type

  -- | Type of the result that 'grad' returns for these @parameters@.
  type Gradients parameters :: Type

  -- | Calculate the gradients of a loss (a zero-dimensional tensor) with
  -- respect to the given independent @parameters@.
  grad :: Loss parameters -> parameters -> Gradients parameters

-- | Gradient of a loss with respect to a single tensor that tracks gradients.
--
-- Both the loss and the resulting gradient are typed as tensors without
-- gradient tracking; the gradient shares the layout, device, data type and
-- shape of the parameter it was taken with respect to.
instance HasGrad (Tensor ('Gradient 'WithGradient) layout device dataType shape) where
  type
    Gradients (Tensor ('Gradient 'WithGradient) layout device dataType shape) =
      Tensor ('Gradient 'WithoutGradient) layout device dataType shape
  type
    Loss (Tensor ('Gradient 'WithGradient) layout device dataType shape) =
      Tensor ('Gradient 'WithoutGradient) layout device dataType shape

  -- Specialised type of this method (kept as a comment because the module
  -- does not enable @InstanceSigs@):
  --
  -- @
  -- grad ::
  --   Tensor ('Gradient 'WithoutGradient) layout device dataType shape ->
  --   Tensor ('Gradient 'WithGradient) layout device dataType shape ->
  --   Tensor ('Gradient 'WithoutGradient) layout device dataType shape
  -- @
  --
  -- 'ATen.grad' works on tensor lists, so the single parameter is wrapped in
  -- a singleton list and the first element of the returned list is the
  -- gradient for that parameter.
  --
  -- NOTE(review): 'unsafePerformIO' presumably relies on 'ATen.grad' having no
  -- side effects observable from pure code -- confirm against the ATen
  -- bindings.
  grad loss parameter =
    case unsafePerformIO (cast2 ATen.grad loss [parameter]) of
      gradient : _ -> gradient
      -- One gradient is expected for the one parameter passed in; fail with a
      -- descriptive message instead of the opaque error from 'head'.
      [] ->
        error
          "Torch.GraduallyTyped.Autograd.grad: ATen.grad returned no gradients for a single parameter"