You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fuse_basic.cc 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/pass/fuse_basic.h"
  17. #include "backend/optimizer/pass/fuse_graph_kernel.h"
  18. #include <memory>
  19. #include <algorithm>
  20. #include <unordered_set>
  21. #include <unordered_map>
  22. #include <vector>
  23. #include <string>
  24. #include "frontend/operator/ops.h"
  25. #include "utils/utils.h"
  26. #include "utils/graph_utils.h"
  27. #include "backend/optimizer/common/helper.h"
  28. #include "backend/session/anf_runtime_algorithm.h"
  29. #include "vm/segment_runner.h"
  30. #include "debug/draw.h"
  31. #include "debug/anf_ir_dump.h"
  32. #include "ir/func_graph_cloner.h"
  33. namespace mindspore {
  34. namespace opt {
  35. namespace {
  36. std::vector<PrimitivePtr> get_fusable_basic_ops(bool is_before_kernel_select) {
  37. std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
  38. prim::kPrimExpandDims};
  39. if (!is_before_kernel_select) {
  40. fusable_basic_ops.push_back(prim::kPrimCast);
  41. }
  42. return fusable_basic_ops;
  43. }
  44. IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
  45. const AnfNodePtr &node) {
  46. if (cur_node == node) {
  47. return FOLLOW;
  48. }
  49. if (!IsPrimitiveCNode(node)) {
  50. return EXCLUDE;
  51. }
  52. auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select);
  53. bool is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
  54. [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  55. return is_fusable ? FOLLOW : EXCLUDE;
  56. }
  57. std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
  58. GraphKernelInfo info;
  59. info.is_before_kernel_select = is_before_kernel_select;
  60. // Search fusable nodes according input direction.
  61. auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1);
  62. auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
  63. if (used_nodes.size() > 1) {
  64. used_nodes = RemoveCircle(used_nodes, false);
  65. }
  66. TopoSortForNodeList(&used_nodes);
  67. return used_nodes;
  68. }
  69. void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) {
  70. AnfNodeSet outputs_set;
  71. for (auto out : *outputs) {
  72. outputs_set.insert(out);
  73. }
  74. AnfNodePtrList vir_outputs;
  75. std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
  76. auto fg_outputs = fg->output();
  77. if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) {
  78. auto cnode = fg_outputs->cast<CNodePtr>();
  79. for (size_t i = 1; i < cnode->size(); ++i) {
  80. vir_outputs.push_back(cnode->input(i));
  81. }
  82. } else {
  83. vir_outputs.push_back(fg_outputs);
  84. }
  85. if (vir_outputs.size() != outputs->size()) {
  86. MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
  87. }
  88. bool has_erase_outs = false;
  89. size_t index = -1;
  90. for (auto it = outputs->begin(); it != outputs->end();) {
  91. index++;
  92. auto out = *it;
  93. eqv[out] = vir_outputs[index];
  94. auto users = mng->node_users()[out];
  95. bool is_only_control_depend_use = true;
  96. std::vector<size_t> control_depend_use_index;
  97. std::vector<CNodePtr> control_depend_nodes;
  98. AnfNodePtr use_out = nullptr;
  99. for (auto &user : users) {
  100. auto use_node = user.first;
  101. if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) {
  102. is_only_control_depend_use = false;
  103. continue;
  104. }
  105. if (outputs_set.count(use_node) != 0) {
  106. use_out = use_node;
  107. }
  108. if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) {
  109. control_depend_nodes.push_back(use_node->cast<CNodePtr>());
  110. control_depend_use_index.push_back(user.second);
  111. }
  112. }
  113. if (is_only_control_depend_use && !control_depend_nodes.empty()) {
  114. MS_EXCEPTION_IF_NULL(use_out);
  115. it = outputs->erase(it);
  116. for (size_t i = 0; i < control_depend_nodes.size(); ++i) {
  117. auto control_depend_node = control_depend_nodes[i];
  118. std::vector<AnfNodePtr> new_control_depend_inputs;
  119. for (size_t j = 0; j < control_depend_node->size(); ++j) {
  120. if (j == control_depend_use_index[i]) {
  121. new_control_depend_inputs.push_back(use_out);
  122. } else {
  123. new_control_depend_inputs.push_back(control_depend_node->input(j));
  124. }
  125. }
  126. auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs);
  127. mng->Replace(control_depend_node, new_control_depend);
  128. has_erase_outs = true;
  129. }
  130. } else {
  131. it++;
  132. }
  133. }
  134. if (!has_erase_outs) {
  135. return;
  136. }
  137. AnfNodePtr fg_new_output;
  138. if (outputs->size() > 1) {
  139. std::vector<AnfNodePtr> output_args;
  140. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  141. (void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args),
  142. [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
  143. // Set output for AnfGraph
  144. fg_new_output = fg->NewCNode(output_args);
  145. } else {
  146. fg_new_output = eqv[(*outputs)[0]];
  147. }
  148. fg->set_output(fg_new_output, true);
  149. }
  150. void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::vector<AnfNodePtr> &todos,
  151. std::unordered_set<AnfNodePtr> *fused_ops, bool is_before_kernel_select) {
  152. auto mng = kernel_graph->manager();
  153. for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
  154. auto node = (*iter)->cast<CNodePtr>();
  155. if (node == nullptr) {
  156. continue;
  157. }
  158. if (fused_ops->count(node)) {
  159. continue;
  160. }
  161. auto fusable_basic_ops = get_fusable_basic_ops(is_before_kernel_select);
  162. bool is_basic_op = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
  163. [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  164. if (!is_basic_op || !kernel_graph->nodes().contains(node)) {
  165. continue;
  166. }
  167. auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
  168. if (fuse_nodes.size() <= 1) {
  169. continue;
  170. }
  171. FuncGraphPtr fg;
  172. AnfNodePtrList inputs;
  173. AnfNodePtrList outputs;
  174. std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
  175. RemoveControlDependOut(fg, &outputs, mng);
  176. auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select);
  177. ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
  178. // Set graph kernel attr
  179. std::string fuse_op_name = "";
  180. for (auto &fuse_node : fuse_nodes) {
  181. fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
  182. }
  183. fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
  184. fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
  185. }
  186. }
  187. } // namespace
  188. void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
  189. MS_EXCEPTION_IF_NULL(kernel_graph);
  190. auto mng = kernel_graph->manager();
  191. if (mng == nullptr) {
  192. mng = Manage(kernel_graph, true);
  193. kernel_graph->set_manager(mng);
  194. }
  195. std::unordered_set<AnfNodePtr> fused_ops;
  196. auto todos = TopoSort(kernel_graph->get_return());
  197. std::reverse(todos.begin(), todos.end());
  198. FuseBasic(kernel_graph, todos, &fused_ops, is_before_kernel_select);
  199. }
  200. } // namespace opt
  201. } // namespace mindspore