
grad_override.cpp 9.8 kB

/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb::imperative::python {
namespace {
// Thin wrappers that dispatch a single op through the imperative apply machinery.
std::shared_ptr<Tensor> get_shape(Tensor* x) {
    static auto op = GetVarShape::make();
    return python::apply(op, x)[0];
}

std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    static auto op = Reduce::make();
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
    static auto op = Reshape::make();
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
    static auto op = Broadcast::make();
    return python::apply(op, x, s)[0];
}
// Build a zero-filled tensor of the given shape and dtype by broadcasting a
// zero-initialized scalar; used as the base tensor for scatter-style gradients.
std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
    HostTensorND scalar{cn, {{1}, dtype}};
    std::memset(scalar.raw_ptr(), 0, dtype.size());
    interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
    auto&& t = std::make_shared<Tensor>(handle);
    auto res = broadcast_to(t.get(), shape);
    return res;
}
// Gradient of broadcasting ADD: the output gradient is reduced back to each
// input's original shape.
std::optional<apply_result_t> elemwise_grad_rule(
        ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Elemwise>();
    if (op.mode == Elemwise::Mode::ADD) {
        mgb_assert(ctx.nargs == 2);
        std::array<std::shared_ptr<Tensor>, 2> input_shapes;
        for (size_t i = 0; i < 2; ++i) {
            if (input_requires_grad(ctx, i)) {
                input_shapes[i] = get_shape(ctx.args[i]);
            }
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes = std::move(input_shapes)](
                               BackwardContext&, Tensor* const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(2);
            if (!grad) {
                return ret;
            }
            for (size_t i = 0; i < 2; ++i) {
                if (shapes[i]) {
                    ret[i] = reduce_to(grad, shapes[i].get());
                }
            }
            return ret;
        });
        return apply(ctx);
    }
    return {};
}
// Gradient of Reshape: reshape the output gradient back to the input's
// original shape.
std::optional<apply_result_t> reshape_grad_rule(
        ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 2);
    std::array<std::shared_ptr<Tensor>, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (input_requires_grad(ctx, i)) {
            input_shapes[i] = get_shape(ctx.args[i]);
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes)](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(2);
        if (!grad) {
            return ret;
        }
        for (size_t i = 0; i < 2; ++i) {
            if (shapes[i]) {
                ret[i] = reshape_to(grad, shapes[i].get());
            }
        }
        return ret;
    });
    return apply(ctx);
}
// Gradient of Subtensor: scatter the output gradient into a zero tensor of the
// original input shape via SetSubtensor, reusing the same index items.
std::optional<apply_result_t> subtensor_grad_rule(
        ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size() + 1);
            auto&& zeros = make_empty_tensor(
                    grad->comp_node(), inputs[0].get(), grad->dtype());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i + 1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}
// Gradient of IndexingMultiAxisVec: analogous to Subtensor, scatter the output
// gradient into zeros of the original shape via IndexingSetMultiAxisVec.
std::optional<apply_result_t> indexingMultiAxisVec_grad_rule(
        ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size() + 1);
            auto&& zeros = make_empty_tensor(
                    grad->comp_node(), inputs[0].get(), grad->dtype());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i + 1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}
// Gradient of SUM reduction: broadcast the output gradient back to the input's
// original shape. Only SUM is special-cased here.
std::optional<apply_result_t> reduce_grad_rule(
        ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Reduce>();
    if (op.mode == Reduce::Mode::SUM) {
        if (ctx.nargs != 1) {
            return {};
        }
        std::array<std::shared_ptr<Tensor>, 1> input_shapes;
        if (input_requires_grad(ctx, 0)) {
            input_shapes[0] = get_shape(ctx.args[0]);
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes = std::move(input_shapes)](
                               BackwardContext&, Tensor* const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(1);
            if (grad && shapes[0]) {
                ret[0] = broadcast_to(grad, shapes[0].get());
            }
            return ret;
        });
        return apply(ctx);
    }
    return {};
}
// Gradient of AddAxis (expand_dims): remove the inserted axes again. Axes are
// sorted in descending order so later removals do not shift earlier indices.
std::optional<apply_result_t> addAxis_grad_rule(
        ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<AddAxis>();
    mgb_assert(ctx.nargs == 1);
    bool flag = input_requires_grad(ctx, 0);
    auto&& grad_op = RemoveAxis::make(op.axis);
    std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && flag_) {
            ret[0] = python::apply(grad_op_, grad)[0];
        }
        return ret;
    });
    return apply(ctx);
}
// Gradient of RemoveAxis (squeeze): re-insert the removed axes. Axes are sorted
// in ascending order so each insertion lands at its original position.
std::optional<apply_result_t> removeAxis_grad_rule(
        ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
    mgb_assert(ctx.nargs == 1);
    bool flag = input_requires_grad(ctx, 0);
    auto&& grad_op = AddAxis::make(op.axis);
    std::sort(grad_op->axis.begin(), grad_op->axis.end());
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && flag_) {
            ret[0] = python::apply(grad_op_, grad)[0];
        }
        return ret;
    });
    return apply(ctx);
}
// Gradient of FastpathCopy: the op is an identity copy, so the gradient passes
// through unchanged.
std::optional<apply_result_t> fastpathcopy_grad_rule(
        ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 1);
    maker.output_size(1).output_captured(0, false);
    maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad) {
            ret[0] = grad->shared_from_this();
        }
        return ret;
    });
    return apply(ctx);
}
// Register the custom gradient rules above at static-initialization time.
struct Init {
    Init() {
        auto& reg = grad_rule_registry();
        reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
        reg.emplace(Reshape::typeinfo(), reshape_grad_rule);
        reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
        reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
        reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
        reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
        reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
    }
} _;

} // namespace
} // namespace mgb::imperative::python

The MegEngine installation package bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU or GPU build to choose from. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.
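As a quick post-install check, the minimal Python sketch below (it assumes only the public megengine package and its is_cuda_available() call) reports whether the current machine can actually use the GPU:

# Minimal sketch: check whether this MegEngine install can see a CUDA device.
import megengine as mge

if mge.is_cuda_available():
    print("CUDA device detected; GPU execution is available.")
else:
    print("No CUDA device found; running on CPU.")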