You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad.cpp 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. #include "./grad.h"
  2. #include "megbrain/imperative/backward_graph_opt.h"
  3. #include "megbrain/imperative/ops/autogen.h"
  4. #include "megbrain/imperative/proxy_graph_detail.h"
  5. #include "megbrain/imperative/resource_manager.h"
  6. #include "megbrain/utils/mempool.h"
  7. #include "range/v3/all.hpp"
  8. #include "./helper.h"
  9. #include "./transformation.h"
  10. namespace py = pybind11;
  11. namespace views = ranges::views;
  12. namespace mgb::imperative::python {
  13. namespace {
  14. std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map;
  15. }
  16. GradKeyWrapper::GradKeyWrapper() {}
  17. void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
  18. if (nargs != 2) {
  19. throw py::type_error("expect 2 arguments");
  20. }
  21. auto* tw = TensorWrapper::try_cast(args[0]);
  22. if (!tw) {
  23. throw py::type_error("argument 1 must be Tensor");
  24. }
  25. py::object callback;
  26. if (args[1] != Py_None) {
  27. callback = py::reinterpret_borrow<py::object>(args[1]);
  28. }
  29. GenericFunction generic_callback = [=](Span<ValueRef> inputs) -> ValueRefList {
  30. mgb_assert(inputs.size() == 1);
  31. if (callback) {
  32. callback(TensorWrapper::make(py_tensor_type, inputs[0]));
  33. }
  34. return {};
  35. };
  36. auto attached_value = imperative::apply(
  37. AttachGrad(m_key), tw->m_tensor->data(),
  38. FunctionValue::make(generic_callback))[0];
  39. tw->m_tensor->reset(attached_value);
  40. }
  41. void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) {
  42. std::vector<ValueRef> args;
  43. mgb_assert(tensors.size() == grads.size());
  44. for (auto&& tensor : tensors) {
  45. args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
  46. }
  47. for (auto&& grad : grads) {
  48. args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data());
  49. }
  50. imperative::apply(GradBackward(self->m_key), {args.data(), args.size()});
  51. }
  52. pybind11::function GradKeyWrapper::get_backward_closure(
  53. GradKeyWrapper* self, py::list tensors) {
  54. std::vector<ValueRef> args;
  55. for (auto&& tensor : tensors) {
  56. args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
  57. }
  58. auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0];
  59. auto closure = closure_value.as_ref<FunctionValue>();
  60. auto py_function = [closure](std::vector<TensorWrapper*> tensors) {
  61. std::vector<ValueRef> args;
  62. for (auto* tw : tensors) {
  63. args.push_back(tw->m_tensor->data());
  64. }
  65. (*closure)(args);
  66. };
  67. return pybind11::cpp_function(py_function);
  68. }
  69. PyObject* GradKeyWrapper::get_name() {
  70. return py::cast(m_name).release().ptr();
  71. }
  72. void GradKeyWrapper::set_name(py::handle name) {
  73. m_name = py::cast<std::string>(name);
  74. if (m_key) {
  75. m_key->name(m_name);
  76. }
  77. }
  78. PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
  79. if (nargs != 1) {
  80. PyErr_SetString(PyExc_TypeError, "expect 1 argument");
  81. return nullptr;
  82. }
  83. auto* tw = TensorWrapper::try_cast(args[0]);
  84. if (!tw) {
  85. PyErr_SetString(PyExc_TypeError, "expect Tensor");
  86. return nullptr;
  87. }
  88. if (imperative::apply(IsAttachedTo(m_key), tw->m_tensor->data())[0]
  89. .cast<BoolValue>()) {
  90. Py_RETURN_TRUE;
  91. }
  92. Py_RETURN_FALSE;
  93. }
  94. void GradKeyWrapper::enter() {
  95. m_transformation = std::make_shared<GradTransformation>();
  96. m_key = m_transformation->key();
  97. m_key->name(m_name);
  98. grad_key_map[m_key] = this;
  99. m_transformation_guard =
  100. TransformationManager::get_instance()
  101. .register_at<TransformationManager::Grad>(m_transformation);
  102. }
  // Undoes enter(): resets the registration guard first, then drops this key's
  // grad_key_map entry, clears m_key, and finally releases the transformation.
  // Statement order is intentional; keep the guard reset before the key is
  // dropped.
  // NOTE(review): exit() is safe to call again (the map erase and resets
  // tolerate the cleared state) — confirm callers rely on this.
  103. void GradKeyWrapper::exit() {
  104. m_transformation_guard.reset();
  105. grad_key_map.erase(m_key);
  106. m_key = {};
  107. m_transformation.reset();
  108. }
  109. void GradKeyWrapper::suppress() {
  110. m_transformation->suppress();
  111. }
  112. void GradKeyWrapper::resume() {
  113. m_transformation->resume();
  114. }
  115. GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
  116. return grad_key_map.at(key);
  117. }
  118. GradKeyWrapper::~GradKeyWrapper() {}
  119. } // namespace mgb::imperative::python