{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.Autograd where

import Foreign.ForeignPtr
import GHC.Generics
import System.IO.Unsafe
import Torch.Internal.Cast
import Torch.Internal.Class
import qualified Torch.Internal.Managed.Autograd
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Tensor

-- | A wrapper marking a 'Tensor' as an input that 'grad' can differentiate
-- with respect to.
--
-- Note: to create an 'IndependentTensor' use 'makeIndependent';
-- otherwise, Torch will complain the parameter does not require a gradient.
newtype IndependentTensor = IndependentTensor
  { -- | The wrapped 'Tensor', for use in ordinary tensor computations.
    toDependent :: Tensor
  }
  deriving (Show, Generic)

-- | Compute the gradients of @y@ with respect to each of @inputs@.
--
-- Returns a list of gradient tensors (presumably one per input, in input
-- order; NOTE(review): confirm against the C++ binding). @y@ is presumably
-- expected to be a scalar, as for libtorch's @torch::autograd::grad@
-- without explicit @grad_outputs@ — confirm.
--
-- The FFI call runs under 'unsafePerformIO' so that 'grad' can be used as a
-- pure function; this relies on the call having no observable side effects
-- on @y@ or @inputs@.
grad :: Tensor -> [IndependentTensor] -> [Tensor]
grad y inputs =
  unsafePerformIO $
    cast2 Torch.Internal.Managed.Autograd.grad y (map toDependent inputs)

-- | Query whether @t@ has its @requires_grad@ flag set.
--
-- The query runs under 'unsafePerformIO' to present it as pure.
-- NOTE(review): this is only sound if the flag is not changed in place
-- after @t@ is first observed — confirm against 'setRequiresGrad'.
requiresGrad :: Tensor -> Bool
requiresGrad t = unsafePerformIO $ cast1 ATen.tensor_requires_grad t

-- | Return @t@ with its @requires_grad@ flag set to @flag@.
--
-- NOTE(review): libtorch's @set_requires_grad@ mutates its tensor in place.
-- Because this wrapper hides the FFI call behind 'unsafePerformIO', the
-- input @t@ may observe the change too — confirm whether the binding
-- returns an alias or a copy.
setRequiresGrad :: Bool -> Tensor -> Tensor
setRequiresGrad flag t =
  unsafePerformIO $ cast2 ATen.tensor_set_requires_grad_b t flag

-- | Wrap @tensor@ as an 'IndependentTensor' with gradient tracking enabled.
--
-- Equivalent to @'makeIndependentWithRequiresGrad' tensor True@.
makeIndependent :: Tensor -> IO IndependentTensor
makeIndependent tensor = makeIndependentWithRequiresGrad tensor True

-- | Wrap @tensor@ as an 'IndependentTensor', setting its @requires_grad@
-- flag to @flag@.
--
-- The wrapped tensor is whatever the underlying
-- @Torch.Internal.Managed.Autograd.makeIndependent@ binding returns
-- (NOTE(review): presumably a fresh tensor detached from @tensor@'s
-- history — confirm).
makeIndependentWithRequiresGrad :: Tensor -> Bool -> IO IndependentTensor
makeIndependentWithRequiresGrad tensor flag =
  IndependentTensor
    <$> cast2 Torch.Internal.Managed.Autograd.makeIndependent tensor flag