You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fuse_basic.cc 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/pass/fuse_basic.h"
  17. #include "backend/optimizer/pass/fuse_graph_kernel.h"
  18. #include <memory>
  19. #include <algorithm>
  20. #include <unordered_set>
  21. #include <unordered_map>
  22. #include <vector>
  23. #include <string>
  24. #include "base/core_ops.h"
  25. #include "ir/graph_utils.h"
  26. #include "backend/optimizer/common/helper.h"
  27. #include "backend/session/anf_runtime_algorithm.h"
  28. #include "vm/segment_runner.h"
  29. #include "debug/anf_ir_dump.h"
  30. #include "ir/func_graph_cloner.h"
  31. namespace mindspore {
  32. namespace opt {
  33. namespace {
  34. std::vector<PrimitivePtr> get_fusable_basic_ops(bool is_before_kernel_select) {
  35. std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
  36. prim::kPrimExpandDims};
  37. if (!is_before_kernel_select) {
  38. fusable_basic_ops.push_back(prim::kPrimCast);
  39. }
  40. return fusable_basic_ops;
  41. }
  42. IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
  43. const AnfNodePtr &node) {
  44. if (cur_node == node) {
  45. return FOLLOW;
  46. }
  47. if (!IsPrimitiveCNode(node)) {
  48. return EXCLUDE;
  49. }
  50. auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select);
  51. bool is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
  52. [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  53. return is_fusable ? FOLLOW : EXCLUDE;
  54. }
  55. std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
  56. GraphKernelInfo info;
  57. info.is_before_kernel_select = is_before_kernel_select;
  58. // Search fusable nodes according input direction.
  59. auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1);
  60. auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
  61. if (used_nodes.size() > 1) {
  62. used_nodes = RemoveCircle(used_nodes, false);
  63. }
  64. TopoSortForNodeList(&used_nodes);
  65. return used_nodes;
  66. }
  67. void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) {
  68. AnfNodeSet outputs_set;
  69. for (auto out : *outputs) {
  70. outputs_set.insert(out);
  71. }
  72. AnfNodePtrList vir_outputs;
  73. std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
  74. auto fg_outputs = fg->output();
  75. if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) {
  76. auto cnode = fg_outputs->cast<CNodePtr>();
  77. for (size_t i = 1; i < cnode->size(); ++i) {
  78. vir_outputs.push_back(cnode->input(i));
  79. }
  80. } else {
  81. vir_outputs.push_back(fg_outputs);
  82. }
  83. if (vir_outputs.size() != outputs->size()) {
  84. MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
  85. }
  86. bool has_erase_outs = false;
  87. size_t index = -1;
  88. for (auto it = outputs->begin(); it != outputs->end();) {
  89. index++;
  90. auto out = *it;
  91. eqv[out] = vir_outputs[index];
  92. auto users = mng->node_users()[out];
  93. bool is_only_control_depend_use = true;
  94. std::vector<size_t> control_depend_use_index;
  95. std::vector<CNodePtr> control_depend_nodes;
  96. AnfNodePtr use_out = nullptr;
  97. for (auto &user : users) {
  98. auto use_node = user.first;
  99. if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) {
  100. is_only_control_depend_use = false;
  101. continue;
  102. }
  103. if (outputs_set.count(use_node) != 0) {
  104. use_out = use_node;
  105. }
  106. if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) {
  107. control_depend_nodes.push_back(use_node->cast<CNodePtr>());
  108. control_depend_use_index.push_back(user.second);
  109. }
  110. }
  111. if (is_only_control_depend_use && !control_depend_nodes.empty()) {
  112. MS_EXCEPTION_IF_NULL(use_out);
  113. it = outputs->erase(it);
  114. for (size_t i = 0; i < control_depend_nodes.size(); ++i) {
  115. auto control_depend_node = control_depend_nodes[i];
  116. std::vector<AnfNodePtr> new_control_depend_inputs;
  117. for (size_t j = 0; j < control_depend_node->size(); ++j) {
  118. if (j == control_depend_use_index[i]) {
  119. new_control_depend_inputs.push_back(use_out);
  120. } else {
  121. new_control_depend_inputs.push_back(control_depend_node->input(j));
  122. }
  123. }
  124. auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs);
  125. mng->Replace(control_depend_node, new_control_depend);
  126. has_erase_outs = true;
  127. }
  128. } else {
  129. it++;
  130. }
  131. }
  132. if (!has_erase_outs) {
  133. return;
  134. }
  135. AnfNodePtr fg_new_output;
  136. if (outputs->size() > 1) {
  137. std::vector<AnfNodePtr> output_args;
  138. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  139. (void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args),
  140. [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
  141. // Set output for AnfGraph
  142. fg_new_output = fg->NewCNode(output_args);
  143. } else {
  144. fg_new_output = eqv[(*outputs)[0]];
  145. }
  146. fg->set_output(fg_new_output, true);
  147. }
  148. void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::vector<AnfNodePtr> &todos,
  149. std::unordered_set<AnfNodePtr> *fused_ops, bool is_before_kernel_select) {
  150. auto mng = kernel_graph->manager();
  151. for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
  152. auto node = (*iter)->cast<CNodePtr>();
  153. if (node == nullptr) {
  154. continue;
  155. }
  156. if (fused_ops->count(node)) {
  157. continue;
  158. }
  159. auto fusable_basic_ops = get_fusable_basic_ops(is_before_kernel_select);
  160. bool is_basic_op = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
  161. [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  162. if (!is_basic_op || !kernel_graph->nodes().contains(node)) {
  163. continue;
  164. }
  165. auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
  166. if (fuse_nodes.size() <= 1) {
  167. continue;
  168. }
  169. FuncGraphPtr fg;
  170. AnfNodePtrList inputs;
  171. AnfNodePtrList outputs;
  172. std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
  173. RemoveControlDependOut(fg, &outputs, mng);
  174. auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select);
  175. ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
  176. // Set graph kernel attr
  177. std::string fuse_op_name = "";
  178. for (auto &fuse_node : fuse_nodes) {
  179. fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
  180. }
  181. fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
  182. fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
  183. }
  184. }
  185. } // namespace
  186. void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
  187. MS_EXCEPTION_IF_NULL(kernel_graph);
  188. auto mng = kernel_graph->manager();
  189. if (mng == nullptr) {
  190. mng = Manage(kernel_graph, true);
  191. kernel_graph->set_manager(mng);
  192. }
  193. std::unordered_set<AnfNodePtr> fused_ops;
  194. auto todos = TopoSort(kernel_graph->get_return());
  195. std::reverse(todos.begin(), todos.end());
  196. FuseBasic(kernel_graph, todos, &fused_ops, is_before_kernel_select);
  197. }
  198. } // namespace opt
  199. } // namespace mindspore