
grad.h 5.2 kB

/**
 * \file imperative/python/src/grad.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "./tensor.h"

#include <megbrain/utils/small_vector.h>

#include <memory>

namespace mgb::imperative::python {

apply_result_t apply_grad(ApplyContext& ctx);

// A gradient-tracking session: tensors attached to a key are recorded on its
// tape until backward() is run or the key is destroyed.
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
    std::string name;
    bool active = true;
    GradInfo::head_t free_vars_head;
    std::vector<std::weak_ptr<GradFn>> tape;

    ~GradKey();

    void attach(Tensor* tensor, pybind11::object callback);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    void cleanup();
};

// Python-facing wrapper exposing GradKey through pyext17.
struct GradKeyWrapper {
    using wrap_t = pyext17::wrap<GradKeyWrapper>;
    static constexpr auto tp_name = pybind11::detail::_("GradKey");

    std::shared_ptr<GradKey> m_key;

    inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}

    PyObject* get_name();
    void set_name(pybind11::handle name);
    void attach(PyObject* const* args, size_t nargs);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    PyObject* is_attached_to(PyObject* const* args, size_t nargs);
};

struct BackwardContext {
    PyTypeObject* pytype = nullptr;

    auto wrap_tensor(std::shared_ptr<Tensor> t) {
        if (pytype) {
            return TensorWrapper::make(pytype, std::move(t));
        }
        return TensorWrapper::make(std::move(t));
    }

    auto wrap_tensor(Tensor* t) {
        return wrap_tensor(t->shared_from_this());
    }
};

// Holds a user-provided backward function plus per-input/per-output flags.
struct CustomBackward {
    using BackwardFn = std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>;

    BackwardFn m_backward;
    SmallVector<bool, 8> m_input_has_grad;
    struct OutputAttr {
        bool requires_grad = true, captured = true;
    };
    SmallVector<OutputAttr> m_output_attrs;

public:
    // Invoke m_backward on the incoming grads and hand each non-null result
    // to the receiver callback, indexed by input position.
    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        size_t nargs = grads.size();
        Tensor* args[nargs];
        for (size_t i = 0; i < nargs; ++i) {
            args[i] = grads[i];
        }
        auto ret = m_backward(ctx, args, nargs);
        for (size_t i = 0; i < ret.size(); ++i) {
            if (auto&& t = ret[i]) {
                receiver(i, std::move(t));
            }
        }
    }

    bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
    bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
    bool output_captured(size_t i) { return m_output_attrs[i].captured; }

    // Fluent builder used by grad rules to configure a CustomBackward.
    class Maker {
        bool output_size_set = false, input_has_grad_initialized = false;
        CustomBackward& target;
        ApplyContext& ctx;

        void init_input_has_grad() {
            if (!input_has_grad_initialized) {
                input_has_grad_initialized = true;
                target.m_input_has_grad.resize(ctx.nargs, true);
            }
        }

    public:
        Maker(CustomBackward& target_, ApplyContext& ctx_) : target(target_), ctx(ctx_) {}

        template <typename F>
        Maker& backward(F&& f) {
            mgb_assert(!target.m_backward);
            target.m_backward = std::forward<F>(f);
            return *this;
        }
        // mandatory
        Maker& output_size(size_t sz) {
            mgb_assert(!output_size_set);
            output_size_set = true;
            target.m_output_attrs.resize(sz);
            return *this;
        }
        // optional, defaults to all true
        Maker& input_has_grad(size_t i, bool v) {
            init_input_has_grad();
            target.m_input_has_grad.at(i) = v;
            return *this;
        }
        // optional, defaults to all true
        Maker& output_requires_grad(size_t i, bool v) {
            target.m_output_attrs.at(i).requires_grad = v;
            return *this;
        }
        // optional, defaults to all true
        Maker& output_captured(size_t i, bool v) {
            target.m_output_attrs.at(i).captured = v;
            return *this;
        }

        void finalize() {
            mgb_assert(output_size_set);
            init_input_has_grad();
        }
    };

    Maker maker(ApplyContext& ctx) { return {*this, ctx}; }
};

using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();

inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
    return bool(ctx.args[i]->m_grad_info.grad_fn);
}

struct GradRuleFallback : std::exception {};

template <typename T>
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
    return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
}

} // namespace mgb::imperative::python

namespace pybind11::detail {
template <>
struct type_caster<mgb::imperative::python::GradKeyWrapper>
        : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};
} // namespace pybind11::detail
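For orientation, the pieces above fit together as follows: a grad rule is looked up by the op's Typeinfo, configures a CustomBackward through the Maker, and then performs the forward computation itself. Below is a minimal sketch of registering such a rule. It is illustrative only: ScaleOp, apply_scale, and calling apply(ctx) for the forward pass are assumptions made for this sketch, not code from this header (MegEngine's own rules are registered elsewhere in the source tree, e.g. grad_override.cpp).

// Minimal sketch, not from the MegEngine sources: register a grad rule for a
// hypothetical unary op ScaleOp computing y = k * x.
#include "./grad.h"

namespace mgb::imperative::python {
namespace {

// Hypothetical forward helper used to compute k * dy; assumed for this sketch.
std::shared_ptr<Tensor> apply_scale(Tensor* dy);

apply_result_t scale_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 1);
    maker.output_size(1)             // mandatory: one forward output
         .output_captured(0, false)  // the backward below never reads y
         .backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) {
             mgb_assert(ngrads == 1);
             apply_result_t input_grads(1);
             if (Tensor* dy = grads[0]) {
                 // d(k*x)/dx = k, so dx = k * dy.
                 input_grads[0] = apply_scale(dy);
             }
             return input_grads;
         });
    maker.finalize();
    return apply(ctx);  // assumed: run the forward computation as usual
}

// Registered once at static-initialization time; ScaleOp::typeinfo() assumed.
bool registered = register_grad_rule(ScaleOp::typeinfo(), scale_grad_rule);

} // anonymous namespace
} // namespace mgb::imperative::python

Note the division of labor the header implies: Maker::finalize() asserts that output_size() was set, input_has_grad defaults to true for every input, and GradRuleFallback appears to exist so a rule can throw to defer to the generic backward path.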

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has a GPU device and that its driver is installed. If you would like to try deep-learning development on a cloud GPU platform, visit MegStudio.