You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

grad.h 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. /**
  2. * \file imperative/python/src/grad.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "./tensor.h"
  13. #include "megbrain/imperative/ops/utility.h"
  14. #include <megbrain/utils/small_vector.h>
  15. #include <memory>
  16. #include <optional>
namespace mgb::imperative::python {

//! Apply the op described by `ctx` with gradient bookkeeping; returns the
//! forward outputs (defined in the accompanying grad.cpp).
apply_result_t apply_grad(ApplyContext& ctx);
/*!
 * \brief One gradient-recording session. Tensors attached to a key have
 * their backward steps recorded on `tape`; `backward` consumes the tape.
 */
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
    std::string name;
    bool active = true;  // whether this key still records/propagates grads
    GradInfo::head_t free_vars_head;
    // weak refs: tape entries do not keep backward nodes alive by themselves
    std::vector<std::weak_ptr<GradFn>> tape;
    int priority = 0;  // keys below sm_min_priority are blocked (see is_blocked)
    ~GradKey();
    //! Start tracking `tensor` under this key; `callback` receives its grad.
    void attach(Tensor* tensor, pybind11::object callback);
    //! Run backward from `ys` with seed grads `dys` (parameter names not
    //! visible here — see the definition).
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    void cleanup();
    bool is_blocked() const { return priority < sm_min_priority; }
    // presumably gates recording grad-of-grad; confirm against grad.cpp
    inline static bool allow_higher_order_directive = false;

private:
    static int sm_min_priority;
};
/*!
 * \brief Python-facing wrapper of GradKey, exposed via pyext17 as "GradKey".
 */
struct GradKeyWrapper {
    using wrap_t = pyext17::wrap<GradKeyWrapper>;
    static constexpr auto tp_name = pybind11::detail::_("GradKey");

    // created eagerly, so m_key is never null
    std::shared_ptr<GradKey> m_key;
    inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}

    // property accessors exposed to Python
    PyObject* get_name();
    void set_name(pybind11::handle name);
    PyObject* get_priority();
    void set_priority(pybind11::handle priority);
    // vectorcall-style entry points: raw argument array plus count
    void attach(PyObject* const* args, size_t nargs);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    PyObject* is_attached_to(PyObject* const* args, size_t nargs);
};
  47. struct BackwardContext {
  48. PyTypeObject* pytype = nullptr;
  49. auto wrap_tensor(std::shared_ptr<Tensor> t) {
  50. if (pytype) {
  51. return TensorWrapper::make(pytype, std::move(t));
  52. }
  53. return TensorWrapper::make(std::move(t));
  54. }
  55. auto wrap_tensor(Tensor* t) { return wrap_tensor(t->shared_from_this()); }
  56. };
  57. struct CustomBackward {
  58. using BackwardFn =
  59. std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>;
  60. BackwardFn m_backward;
  61. SmallVector<bool, 8> m_input_has_grad;
  62. struct OutputAttr {
  63. bool requires_grad = true, captured = true;
  64. };
  65. SmallVector<OutputAttr> m_output_attrs;
  66. public:
  67. template <typename T, typename R>
  68. void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
  69. size_t nargs = grads.size();
  70. Tensor* args[nargs];
  71. for (size_t i = 0; i < nargs; ++i) {
  72. args[i] = grads[i];
  73. }
  74. auto ret = m_backward(ctx, args, nargs);
  75. for (size_t i = 0; i < ret.size(); ++i) {
  76. if (auto&& t = ret[i]) {
  77. receiver(i, std::move(t));
  78. }
  79. }
  80. }
  81. bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
  82. bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
  83. bool output_captured(size_t i) { return m_output_attrs[i].captured; }
  84. class Maker {
  85. bool output_size_set = false, input_has_grad_initialized = false;
  86. CustomBackward& target;
  87. ApplyContext& ctx;
  88. void init_input_has_grad() {
  89. if (!input_has_grad_initialized) {
  90. input_has_grad_initialized = true;
  91. target.m_input_has_grad.resize(ctx.nargs, true);
  92. }
  93. }
  94. public:
  95. Maker(CustomBackward& target_, ApplyContext& ctx_)
  96. : target(target_), ctx(ctx_) {}
  97. template <typename F>
  98. Maker& backward(F&& f) {
  99. mgb_assert(!target.m_backward);
  100. target.m_backward = std::forward<F>(f);
  101. return *this;
  102. }
  103. // mandatory
  104. Maker& output_size(size_t sz) {
  105. mgb_assert(!output_size_set);
  106. output_size_set = true;
  107. target.m_output_attrs.resize(sz);
  108. return *this;
  109. }
  110. // optional, defaults to all true
  111. Maker& input_has_grad(size_t i, bool v) {
  112. init_input_has_grad();
  113. target.m_input_has_grad.at(i) = v;
  114. return *this;
  115. }
  116. // optional, defaults to all true
  117. Maker& output_requires_grad(size_t i, bool v) {
  118. target.m_output_attrs.at(i).requires_grad = v;
  119. return *this;
  120. }
  121. // optional, defaults to all true
  122. Maker& output_captured(size_t i, bool v) {
  123. target.m_output_attrs.at(i).captured = v;
  124. return *this;
  125. }
  126. void finalize() {
  127. mgb_assert(output_size_set);
  128. init_input_has_grad();
  129. }
  130. };
  131. Maker maker(ApplyContext& ctx) { return {*this, ctx}; }
  132. };
//! Per-op gradient rule: given the apply context and a Maker for configuring
//! backward, optionally computes the forward result itself; an empty optional
//! defers to the generic path (see also GradRuleFallback below).
using GradRuleFn = std::function<std::optional<apply_result_t>(
        ApplyContext&, CustomBackward::Maker&)>;

//! Global registry mapping op Typeinfo to its grad rule.
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
  136. inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
  137. return !ctx.args[i]->m_grad_info_dict.empty();
  138. }
//! Exception type a grad rule can throw to opt out and fall back to the
//! default gradient handling (exact usage lives in grad.cpp — confirm there).
struct GradRuleFallback : std::exception {};
  140. template <typename T>
  141. bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
  142. return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
  143. }
  144. } // namespace mgb::imperative::python
namespace pybind11::detail {

//! Teach pybind11 to convert GradKeyWrapper to/from Python by reusing the
//! caster generated by pyext17::wrap (see wrap_t in GradKeyWrapper).
template <>
struct type_caster<mgb::imperative::python::GradKeyWrapper>
        : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};
} // namespace pybind11::detail

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台