
grad.h

/**
 * \file imperative/python/src/grad.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "./tensor.h"
#include "megbrain/imperative/ops/utility.h"
#include <megbrain/utils/small_vector.h>

#include <functional>
#include <memory>
#include <unordered_map>

namespace mgb::imperative::python {

apply_result_t apply_grad(ApplyContext& ctx);
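
// One GradKey represents an active gradient-recording session: tensors are
// attach()ed to it, recorded GradFns accumulate on the tape, and backward()
// consumes the tape; cleanup() drops the recorded state.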
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
    std::string name;
    bool active = true;
    GradInfo::head_t free_vars_head;
    std::vector<std::weak_ptr<GradFn>> tape;
    int priority = 0;

    ~GradKey();

    void attach(Tensor* tensor, pybind11::object callback);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    void cleanup();

    bool is_blocked() const { return priority < sm_min_priority; }

    inline static bool allow_higher_order_directive = false;

private:
    static int sm_min_priority;
};
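
// Python-facing wrapper exposed as "GradKey" through pyext17; it owns the
// shared GradKey and forwards attach/backward plus the name and priority
// properties to it.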
struct GradKeyWrapper {
    using wrap_t = pyext17::wrap<GradKeyWrapper>;
    static constexpr auto tp_name = pybind11::detail::_("GradKey");

    std::shared_ptr<GradKey> m_key;

    inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}

    PyObject* get_name();
    void set_name(pybind11::handle name);
    PyObject* get_priority();
    void set_priority(pybind11::handle priority);
    void attach(PyObject* const* args, size_t nargs);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    PyObject* is_attached_to(PyObject* const* args, size_t nargs);
};
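
// Passed through backward execution; wrap_tensor() rewraps a produced
// tensor as a Python object, preferring the subclass type captured in
// pytype and falling back to the default TensorWrapper type otherwise.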
struct BackwardContext {
    PyTypeObject* pytype = nullptr;

    auto wrap_tensor(std::shared_ptr<Tensor> t) {
        if (pytype) {
            return TensorWrapper::make(pytype, std::move(t));
        }
        return TensorWrapper::make(std::move(t));
    }

    auto wrap_tensor(Tensor* t) { return wrap_tensor(t->shared_from_this()); }
};
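
// Type-erased backward for one forward application: m_backward maps the
// incoming output grads to input grads, m_input_has_grad marks which inputs
// expect a grad, and m_output_attrs records, per output, whether it requires
// grad and whether it must be captured for the backward pass.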
struct CustomBackward {
    using BackwardFn = std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>;
    BackwardFn m_backward;
    SmallVector<bool, 8> m_input_has_grad;
    struct OutputAttr {
        bool requires_grad = true, captured = true;
    };
    SmallVector<OutputAttr> m_output_attrs;

public:
    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        size_t nargs = grads.size();
        // Note: variable-length array; relies on a GCC/Clang extension.
        Tensor* args[nargs];
        for (size_t i = 0; i < nargs; ++i) {
            args[i] = grads[i];
        }
        auto ret = m_backward(ctx, args, nargs);
        // Hand each non-null input grad to the receiver callback.
        for (size_t i = 0; i < ret.size(); ++i) {
            if (auto&& t = ret[i]) {
                receiver(i, std::move(t));
            }
        }
    }

    bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
    bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
    bool output_captured(size_t i) { return m_output_attrs[i].captured; }

    // Fluent builder used by grad rules to fill in a CustomBackward.
    class Maker {
        bool output_size_set = false, input_has_grad_initialized = false;
        CustomBackward& target;
        ApplyContext& ctx;

        void init_input_has_grad() {
            if (!input_has_grad_initialized) {
                input_has_grad_initialized = true;
                target.m_input_has_grad.resize(ctx.nargs, true);
            }
        }

    public:
        Maker(CustomBackward& target_, ApplyContext& ctx_) : target(target_), ctx(ctx_) {}

        template <typename F>
        Maker& backward(F&& f) {
            mgb_assert(!target.m_backward);
            target.m_backward = std::forward<F>(f);
            return *this;
        }

        // mandatory
        Maker& output_size(size_t sz) {
            mgb_assert(!output_size_set);
            output_size_set = true;
            target.m_output_attrs.resize(sz);
            return *this;
        }

        // optional, defaults to all true
        Maker& input_has_grad(size_t i, bool v) {
            init_input_has_grad();
            target.m_input_has_grad.at(i) = v;
            return *this;
        }

        // optional, defaults to all true
        Maker& output_requires_grad(size_t i, bool v) {
            target.m_output_attrs.at(i).requires_grad = v;
            return *this;
        }

        // optional, defaults to all true
        Maker& output_captured(size_t i, bool v) {
            target.m_output_attrs.at(i).captured = v;
            return *this;
        }

        void finalize() {
            mgb_assert(output_size_set);
            init_input_has_grad();
        }
    };

    Maker maker(ApplyContext& ctx) { return {*this, ctx}; }
};
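
// A grad rule runs the forward computation for one op type and configures
// its backward through the Maker; rules are registered per op Typeinfo and
// may throw GradRuleFallback to defer to the generic gradient path.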
using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();

inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
    return !ctx.args[i]->m_grad_info_dict.empty();
}

struct GradRuleFallback : std::exception {};

template <typename T>
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
    return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
}

} // namespace mgb::imperative::python

namespace pybind11::detail {

template <>
struct type_caster<mgb::imperative::python::GradKeyWrapper>
        : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};

} // namespace pybind11::detail
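
For orientation, the pieces above combine as follows: a grad rule looked up in grad_rule_registry() runs the forward computation and uses CustomBackward::Maker to describe the backward pass. Below is a minimal sketch for a hypothetical single-input, single-output op; MyOp, my_op_forward, and my_op_backward are placeholder names invented for illustration, not MegEngine APIs.

// Sketch only. MyOp, my_op_forward and my_op_backward are hypothetical.
namespace mgb::imperative::python {

apply_result_t my_op_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    maker.output_size(1)  // mandatory: declare the number of outputs first
            .output_captured(0, false)  // optional: backward never reads the output
            .backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) {
                // One incoming grad per output; return one grad per input.
                mgb_assert(ngrads == 1);
                apply_result_t input_grads(1);
                if (grads[0]) {
                    input_grads[0] = my_op_backward(grads[0]);  // placeholder
                }
                return input_grads;
            });
    return my_op_forward(ctx);  // placeholder forward computation
}

// Register once at startup, keyed by the op's Typeinfo; emplace() makes the
// call return false if a rule for MyOp was already registered.
static bool my_op_rule_registered =
        register_grad_rule(MyOp::typeinfo(), my_op_grad_rule);

} // namespace mgb::imperative::python

Judging from the declarations, a rule that cannot handle a particular call can also throw GradRuleFallback so the caller falls back to a generic gradient path, and input_requires_grad(ctx, i) lets a rule skip work for inputs that are not being differentiated.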

The MegEngine installation package bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU installed along with working drivers. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.