You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad.cpp 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /**
  2. * \file imperative/python/src/grad.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./grad.h"
  12. #include "megbrain/imperative/backward_graph_opt.h"
  13. #include "megbrain/imperative/ops/autogen.h"
  14. #include "megbrain/imperative/proxy_graph_detail.h"
  15. #include "megbrain/imperative/resource_manager.h"
  16. #include "megbrain/utils/mempool.h"
  17. #include "range/v3/all.hpp"
  18. #include "./helper.h"
  19. #include "./transformation.h"
  20. namespace py = pybind11;
  21. namespace views = ranges::views;
  22. namespace mgb::imperative::python {
  23. namespace {
  24. std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map;
  25. }
  26. GradKeyWrapper::GradKeyWrapper() {}
  27. void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
  28. if (nargs != 2) {
  29. throw py::type_error("expect 2 arguments");
  30. }
  31. auto* tw = TensorWrapper::try_cast(args[0]);
  32. if (!tw) {
  33. throw py::type_error("argument 1 must be Tensor");
  34. }
  35. py::object callback;
  36. if (args[1] != Py_None) {
  37. callback = py::reinterpret_borrow<py::object>(args[1]);
  38. }
  39. GenericFunction generic_callback = [=](Span<ValueRef> inputs) -> ValueRefList {
  40. mgb_assert(inputs.size() == 1);
  41. if (callback) {
  42. callback(TensorWrapper::make(py_tensor_type, inputs[0]));
  43. }
  44. return {};
  45. };
  46. auto attached_value = imperative::apply(
  47. AttachGrad(m_key), tw->m_tensor->data(),
  48. FunctionValue::make(generic_callback))[0];
  49. tw->m_tensor->reset(attached_value);
  50. }
  51. void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) {
  52. std::vector<ValueRef> args;
  53. mgb_assert(tensors.size() == grads.size());
  54. for (auto&& tensor : tensors) {
  55. args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
  56. }
  57. for (auto&& grad : grads) {
  58. args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data());
  59. }
  60. imperative::apply(GradBackward(self->m_key), {args.data(), args.size()});
  61. }
  62. pybind11::function GradKeyWrapper::get_backward_closure(
  63. GradKeyWrapper* self, py::list tensors) {
  64. std::vector<ValueRef> args;
  65. for (auto&& tensor : tensors) {
  66. args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
  67. }
  68. auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0];
  69. auto closure = closure_value.as_ref<FunctionValue>();
  70. auto py_function = [closure](std::vector<TensorWrapper*> tensors) {
  71. std::vector<ValueRef> args;
  72. for (auto* tw : tensors) {
  73. args.push_back(tw->m_tensor->data());
  74. }
  75. (*closure)(args);
  76. };
  77. return pybind11::cpp_function(py_function);
  78. }
  79. PyObject* GradKeyWrapper::get_name() {
  80. return py::cast(m_name).release().ptr();
  81. }
  82. void GradKeyWrapper::set_name(py::handle name) {
  83. m_name = py::cast<std::string>(name);
  84. if (m_key) {
  85. m_key->name(m_name);
  86. }
  87. }
  88. PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
  89. if (nargs != 1) {
  90. PyErr_SetString(PyExc_TypeError, "expect 1 argument");
  91. return nullptr;
  92. }
  93. auto* tw = TensorWrapper::try_cast(args[0]);
  94. if (!tw) {
  95. PyErr_SetString(PyExc_TypeError, "expect Tensor");
  96. return nullptr;
  97. }
  98. if (imperative::apply(IsAttachedTo(m_key), tw->m_tensor->data())[0]
  99. .cast<BoolValue>()) {
  100. Py_RETURN_TRUE;
  101. }
  102. Py_RETURN_FALSE;
  103. }
  104. void GradKeyWrapper::enter() {
  105. m_transformation = std::make_shared<GradTransformation>();
  106. m_key = m_transformation->key();
  107. m_key->name(m_name);
  108. grad_key_map[m_key] = this;
  109. TransformationManager::get_instance().register_at<TransformationManager::Grad>(
  110. m_transformation);
  111. }
  112. void GradKeyWrapper::exit() {
  113. TransformationManager::get_instance().unregister<TransformationManager::Grad>(
  114. m_transformation);
  115. grad_key_map.erase(m_key);
  116. m_key = {};
  117. m_transformation.reset();
  118. }
  119. void GradKeyWrapper::suppress() {
  120. m_transformation->suppress();
  121. }
  122. void GradKeyWrapper::resume() {
  123. m_transformation->resume();
  124. }
  125. GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
  126. return grad_key_map.at(key);
  127. }
  128. GradKeyWrapper::~GradKeyWrapper() {}
  129. } // namespace mgb::imperative::python