
backward_graph_opt.cpp 4.0 kB

/**
 * \file imperative/src/impl/backward_graph_opt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"

using namespace mgb;
using namespace imperative;

OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubgraph& src)
        : input_has_grad(src.output_mask) {
    if (src.graph.exprs.size() <= 1) {
        // backward graph only contains a single op
        backward = src.graph;
        save_for_backward = src.input_mask;
        return;
    }
    save_for_backward.resize(src.input_mask.size(), false);

    auto&& graph = src.graph;
    auto&& mask = src.input_mask;
    size_t input_size = src.output_mask.size();
    size_t output_size = (mask.size() - input_size) / 2;
    mgb_assert(input_size + output_size * 2 == mask.size());

    auto& fgraph = precomp;
    auto& bgraph = backward;

    // optimization: move ops (e.g. GetVarShape) to forward to
    // reduce memory footprint
    struct VInfo {
        bool appears_in_backward = false;
    };
    std::unordered_map<size_t, VInfo> vinfo;

    // step 1.1: ops not in whitelist must run in backward.
    // mark their inputs as always appears in backward
    for (auto&& [op, iv, ov] : graph.exprs) {
        if (!op->same_type<GetVarShape>()) {
            for (auto&& v : iv) {
                vinfo[v].appears_in_backward = true;
            }
        }
    }

    // step 1.2: inputs only available in backward (i.e. grads)
    // should be marked as always appears in backward
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (!mask[i]) continue;
        if (i >= input_size + output_size) {
            vinfo[graph.inputs[j]].appears_in_backward = true;
        }
        ++j;
    }

    // step 2: try to move ops to forward, if not all their inputs
    // are marked always appears in backward (otherwise no memory saving)
    for (auto&& expr : graph.exprs) {
        auto&& [op, iv, ov] = expr;
        if (std::all_of(iv.begin(), iv.end(),
                        [&](auto&& v) { return vinfo[v].appears_in_backward; })) {
            bgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                vinfo[v].appears_in_backward = true;
            }
            // logically should also mark all inputs as appears in backward
            // but clearly that's a no-op.
        } else {
            fgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                if (vinfo[v].appears_in_backward) {
                    // appears_in_backward won't change after this point
                    // so it is safe to set fgraph.outputs based on current value
                    fgraph.outputs.push_back(v);
                }
            }
        }
    }

    // initialize remaining parts
    fgraph.constants = graph.constants;
    fgraph.inputs.reserve(input_size + output_size);
    for (size_t i = 0, j = 0; i < input_size + output_size; ++i) {
        if (!mask[i]) {
            fgraph.inputs.push_back(1000000000 + i);
            continue;
        }
        fgraph.inputs.push_back(graph.inputs[j++]);
    }

    bgraph.constants = graph.constants;
    bgraph.outputs = graph.outputs;
    bgraph.inputs = fgraph.outputs;
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            auto&& v = graph.inputs[j++];
            if (vinfo[v].appears_in_backward) {
                save_for_backward[i] = true;
                bgraph.inputs.push_back(v);
            }
        }
    }

    if (!fgraph.outputs.size()) {
        precomp = {};
    }
}
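
For reference, here is a minimal standalone sketch of the partition rule applied in steps 1 and 2 above: an op stays in the backward graph only if every one of its inputs already has to be alive in backward; otherwise it is hoisted into the forward (precomp) graph so its inputs need not be saved. The types and names below (Expr, Partition, whitelisted, backward_only_inputs) are simplified stand-ins for illustration, not MegEngine APIs.

#include <algorithm>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Hypothetical simplified expression: a whitelist flag plus variable ids.
struct Expr {
    bool whitelisted;             // e.g. GetVarShape-like ops eligible to move forward
    std::vector<size_t> inputs;   // input variable ids
    std::vector<size_t> outputs;  // output variable ids
};

struct Partition {
    std::vector<Expr> forward;   // ops hoisted into the forward (precomp) graph
    std::vector<Expr> backward;  // ops that must stay in the backward graph
};

Partition partition(const std::vector<Expr>& exprs,
                    const std::vector<size_t>& backward_only_inputs) {
    std::unordered_map<size_t, bool> appears_in_backward;
    // step 1.1: inputs of non-whitelisted ops must live in backward anyway
    for (auto&& e : exprs) {
        if (!e.whitelisted) {
            for (auto v : e.inputs) appears_in_backward[v] = true;
        }
    }
    // step 1.2: values only available in backward (e.g. output grads)
    for (auto v : backward_only_inputs) appears_in_backward[v] = true;

    // step 2: keep an op in backward only if all its inputs already live there;
    // otherwise hoisting it forward can shrink what must be saved for backward
    Partition p;
    for (auto&& e : exprs) {
        bool all_backward = std::all_of(
                e.inputs.begin(), e.inputs.end(),
                [&](size_t v) { return appears_in_backward[v]; });
        if (all_backward) {
            p.backward.push_back(e);
            for (auto v : e.outputs) appears_in_backward[v] = true;
        } else {
            p.forward.push_back(e);
        }
    }
    return p;
}

The real constructor does additional bookkeeping that this sketch omits: it records which hoisted outputs feed the backward graph (fgraph.outputs, reused as bgraph.inputs) and which original inputs must be saved for backward (save_for_backward).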

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no need to choose between separate CPU and GPU builds. To run GPU programs, make sure the machine has GPU hardware and that the driver is properly installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.