{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
module Torch.Typed.Autograd
( Torch.Typed.Autograd.HasGrad,
Torch.Typed.Autograd.grad,
)
where
import Data.Kind
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Autograd as LibTorch
import qualified Torch.Tensor as D
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Relates an input type @a@ to a result type @b@.
--
-- The functional dependency @a -> b@ means the input type alone fixes the
-- result type, so callers never have to annotate @b@.
class HasGrad a b | a -> b where
  -- | Given a zero-dimensional tensor (shape @'[]@) and an input of type @a@,
  -- produce a @b@.
  --
  -- The tensor's @device@ and @dtype@ are quantified here, independently of
  -- whatever device or dtype appear inside @a@.
  --
  -- NOTE(review): presumably the result holds the gradients of the scalar
  -- with respect to the input - confirm against the instances.
  grad ::
    forall dtype device.
    Tensor device dtype '[] ->
    a ->
    b

  -- | Convert the input representation @a@ into its result representation @b@.
  toDependent :: a -> b
-- | Single-parameter instance: the result for a 'Parameter' is a 'Tensor'
-- with the same device, dtype and shape.
instance HasGrad (Parameter device dtype shape) (Tensor device dtype shape) where
  -- The parameter is converted with 'toDependent', wrapped in a one-element
  -- list, and passed together with @loss@ to 'LibTorch.grad' through
  -- 'ATen.cast2'. The first element of the returned list is the result.
  --
  -- NOTE(review): 'unsafePerformIO' assumes the underlying call behaves as a
  -- pure function of its arguments - confirm before relying on it.
  grad loss input =
    case gradients of
      -- Only the first element is used, exactly as the previous 'head' did.
      (g : _) -> g
      -- 'head' would fail here with an uninformative message; make the
      -- failure point explicit instead.
      [] ->
        error
          "Torch.Typed.Autograd.grad: LibTorch.grad returned no gradient for a single parameter"
    where
      gradients =
        unsafePerformIO $
          ATen.cast2
            LibTorch.grad
            loss
            [Torch.Typed.Autograd.toDependent input]

  -- Delegates to the function of the same name from "Torch.Typed.Parameter".
  toDependent = Torch.Typed.Parameter.toDependent
-- | Base case for heterogeneous lists: the empty list maps to the empty list.
instance HasGrad (HList ('[] :: [Type])) (HList ('[] :: [Type])) where
  -- The scalar argument is ignored and the empty list is returned unchanged;
  -- no backend call is made.
  grad _ = id

  -- Nothing to convert in an empty list.
  toDependent = id
-- | Inductive case for heterogeneous lists.
--
-- Requires a 'HasGrad' instance for the head element and for the tail, plus
-- a 'ATen.Castable' instance so that the resulting @HList (b ': bs)@ can be
-- marshalled by 'ATen.cast2' as a list of @D.ATenTensor@.
instance
  ( HasGrad a b,
    HasGrad (HList as) (HList bs),
    ATen.Castable (HList (b ': bs)) [D.ATenTensor]
  ) =>
  HasGrad (HList (a ': as)) (HList (b ': bs))
  where
  -- The whole input list is first converted element-wise with 'toDependent',
  -- then handed to 'LibTorch.grad' in a single call through 'ATen.cast2'.
  --
  -- NOTE(review): 'unsafePerformIO' assumes the underlying call behaves as a
  -- pure function of its arguments - confirm before relying on it.
  grad loss inputs =
    unsafePerformIO $
      ATen.cast2
        LibTorch.grad
        loss
        (Torch.Typed.Autograd.toDependent inputs)

  -- Convert the head with its own instance and recurse on the tail.
  toDependent (x :. xs) =
    Torch.Typed.Autograd.toDependent x :. Torch.Typed.Autograd.toDependent xs