{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
module Torch.GraduallyTyped.Autograd where
import Data.Kind (Type)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..))
import Torch.GraduallyTyped.Tensor (Tensor)
import Torch.Internal.Cast (cast2)
import qualified Torch.Internal.Managed.Autograd as ATen
-- | Types that support reverse-mode automatic differentiation.
--
-- An instance fixes, for a collection of @parameters@, the type of the
-- gradients produced and the type of the loss that is differentiated.
class HasGrad parameters where
  -- | The type of the gradients computed for @parameters@.
  type Gradients parameters :: Type

  -- | The type of the loss value differentiated with respect to
  -- @parameters@.
  type Loss parameters :: Type

  -- | Compute the gradients of a loss with respect to the parameters.
  grad :: Loss parameters -> parameters -> Gradients parameters
-- | Gradient computation for a single tensor that tracks gradients.
--
-- Both the loss and the resulting gradient are tensors that do not
-- themselves track gradients ('WithoutGradient').
instance HasGrad (Tensor ('Gradient 'WithGradient) layout device dataType shape) where
  type
    Gradients (Tensor ('Gradient 'WithGradient) layout device dataType shape) =
      Tensor ('Gradient 'WithoutGradient) layout device dataType shape

  type
    Loss (Tensor ('Gradient 'WithGradient) layout device dataType shape) =
      Tensor ('Gradient 'WithoutGradient) layout device dataType shape

  -- NOTE(review): 'unsafePerformIO' assumes 'ATen.grad' is observably pure
  -- for a fixed graph — TODO confirm. 'head' is partial; it is relied upon
  -- that ATen returns exactly one gradient per input tensor, and exactly
  -- one parameter is passed here — verify against the ATen binding.
  grad loss parameter = head . unsafePerformIO $ cast2 ATen.grad loss [parameter]