{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Raw-pointer ('Ptr') bindings to libtorch autograd helpers, implemented
-- with inline-c-cpp blocks.
module Torch.Internal.Unmanaged.Autograd where

import Foreign.Ptr
import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Cpp.Unsafe as C
import qualified Language.C.Inline.Context as C
import qualified Language.C.Types as C
import Foreign.C.Types (CBool)

import Torch.Internal.Type

C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable }

C.include "<vector>"
C.include "<torch/types.h>"
C.include "<torch/csrc/autograd/variable.h>"
C.include "<torch/csrc/autograd/engine.h>"
C.include "<ATen/core/functional.h>"

-- | Compute the gradient of the single-element tensor @y@ with respect to
-- each tensor in @inputs@, using libtorch's default autograd engine.
--
-- * The backward pass is seeded with @torch::ones_like(y)@.
-- * @keep_graph=true@: the graph is kept, so it can be differentiated again.
-- * @create_graph=false@: no graph is recorded for the returned gradients.
-- * @accumulate_grad=false@: gradients are returned instead of being
--   accumulated into the inputs
--   (see <https://github.com/pytorch/pytorch/pull/46855>).
--
-- The result holds one entry per input, in input order. It is a vector
-- allocated with C++ @new@, and the caller owns it.
--
-- A C++ @std::runtime_error@ is thrown (and raised via 'C.throwBlock') when:
--
-- * @y@ has no gradient edge, i.e. it does not require grad;
-- * @y@ does not contain exactly one element;
-- * any tensor in @inputs@ does not require grad.
--
-- An input that requires grad but has neither a @grad_fn@ nor an existing
-- gradient accumulator gets an empty output edge.
-- NOTE(review): the engine presumably returns an undefined tensor in that
-- slot; confirm this.
grad :: Ptr Tensor -> Ptr TensorList -> IO (Ptr TensorList)
grad y inputs = [C.throwBlock| std::vector<at::Tensor>* {
    torch::autograd::Variable y = *$(at::Tensor* y);
    const auto & inputs = *$(std::vector<at::Tensor>* inputs);

    // The edge from y into the graph is where the backward pass starts.
    torch::autograd::edge_list roots { torch::autograd::impl::gradient_edge(y) };
    if (!roots[0].function) {
      throw std::runtime_error("Differentiated tensor not require grad");
    }
    if (y.numel() != 1) {
      throw std::runtime_error("Differentiated tensor has more than a single element");
    }
    torch::autograd::variable_list grads { torch::ones_like(y) };

    // One output edge per input, in input order. The engine returns the
    // gradient flowing into each of these edges.
    torch::autograd::edge_list output_edges;
    output_edges.reserve(inputs.size());
    for (const auto & input : inputs) {
      // Check this first; the lookups below have no side effects.
      if (!input.requires_grad()) {
        throw std::runtime_error("One of the differentiated Tensors does not require grad");
      }
      const auto output_nr = input.output_nr();
      auto grad_fn = input.grad_fn();
      if (!grad_fn) {
        // No grad_fn (a leaf): fall back to its gradient accumulator, if
        // one already exists.
        grad_fn = torch::autograd::impl::try_get_grad_accumulator(input);
      }
      if (!grad_fn) {
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }

    auto & engine = torch::autograd::Engine::get_default_engine();
    auto outputs = engine.execute(
      roots,
      grads,
      /*keep_graph=*/true,
      /*create_graph=*/false,
      // https://github.com/pytorch/pytorch/pull/46855
      // https://github.com/pytorch/pytorch/issues/46373
      /*accumulate_grad=*/false,
      output_edges);
    return new std::vector<at::Tensor>(at::fmap<at::Tensor>(outputs));
  }|]

-- | Return @detach()@ of @tensor@, with its @requires_grad@ flag set to
-- @requires_grad@.
--
-- The result is allocated with C++ @new@, and the caller owns it.
makeIndependent :: Ptr Tensor -> CBool -> IO (Ptr Tensor)
makeIndependent tensor requires_grad = [C.throwBlock| at::Tensor* {
    return new at::Tensor($(at::Tensor* tensor)->detach().set_requires_grad($(bool requires_grad)));
  }|]

-- | Return @detach()@ of @t@ after clearing the detached tensor's autograd
-- metadata with @set_autograd_meta(nullptr)@. The result carries no
-- autograd state.
--
-- The result is allocated with C++ @new@, and the caller owns it.
dropVariable :: Ptr Tensor -> IO (Ptr Tensor)
dropVariable t = [C.throwBlock| at::Tensor* {
    auto ret = $(at::Tensor* t)->detach();
    ret.unsafeGetTensorImpl()->set_autograd_meta(nullptr);
    return new at::Tensor(ret);
  }|]