{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Autograd
  ( Torch.Typed.Autograd.HasGrad,
    Torch.Typed.Autograd.grad,
  )
where

import Data.Kind
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Autograd as LibTorch
import qualified Torch.Tensor as D
import Torch.Typed.Parameter
import Torch.Typed.Tensor

class HasGrad a b | a -> b where
  -- | compute the gradient of a zero-dimensional (scalar) tensor, e.g. a loss, with respect to a parameter or a list of parameters
  grad :: forall dtype device. Tensor device dtype '[] -> a -> b

  -- | extract the underlying tensor(s) from the parameter(s)
  toDependent :: a -> b

-- instance HasGrad (Tensor device dtype shape) (Tensor device dtype shape) where
--   grad loss input = head . unsafePerformIO $ ATen.cast2
--     Torch.Managed.Autograd.grad
--     loss
--     [Torch.Typed.Autograd.toDependent input]
--   toDependent = id

instance HasGrad (Parameter device dtype shape) (Tensor device dtype shape) where
  grad loss input =
    head . unsafePerformIO $
      ATen.cast2
        LibTorch.grad
        loss
        [Torch.Typed.Autograd.toDependent input]
  toDependent = Torch.Typed.Parameter.toDependent
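
-- A minimal usage sketch for the 'Parameter' instance above. The names @w@,
-- @loss@, and @dw@ are hypothetical; 'makeIndependent', 'ones', and 'sumAll'
-- are assumed to be in scope from the typed Torch API:
--
-- > w <- makeIndependent (ones :: CPUTensor 'D.Float '[2, 2])
-- > let loss = sumAll (Torch.Typed.Parameter.toDependent w)  -- zero-dimensional
-- > let dw   = grad loss w  -- gradient with the same shape as @w@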

instance HasGrad (HList ('[] :: [Type])) (HList ('[] :: [Type])) where
  grad _ = id
  toDependent = id

instance
  ( HasGrad a b,
    HasGrad (HList as) (HList bs),
    ATen.Castable (HList (b ': bs)) [D.ATenTensor]
  ) =>
  HasGrad (HList (a ': as)) (HList (b ': bs))
  where
  grad loss inputs =
    unsafePerformIO $
      ATen.cast2
        LibTorch.grad
        loss
        (Torch.Typed.Autograd.toDependent inputs)
  toDependent (a :. as) =
    Torch.Typed.Autograd.toDependent a :. Torch.Typed.Autograd.toDependent as
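
-- A minimal sketch of taking gradients with respect to several parameters at
-- once, using the 'HList' instances above. The bindings @w1@, @w2@, and
-- @loss@ are hypothetical:
--
-- > let params    = w1 :. w2 :. HNil
-- > let gradients = grad loss params  -- an 'HList' of gradient tensors
--
-- By the @a -> b@ functional dependency, each gradient tensor has the same
-- device, dtype, and shape as the parameter it was taken with respect to.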