You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimize_assign.cc 8.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/graph_kernel/optimize_assign.h"
  17. #include <algorithm>
  18. #include <vector>
  19. #include <string>
  20. #include <map>
  21. #include "base/core_ops.h"
  22. #include "backend/common/optimizer/helper.h"
  23. #include "backend/common/session/anf_runtime_algorithm.h"
  24. #include "include/common/utils/anfalgo.h"
  25. #include "common/graph_kernel/graph_kernel_helper.h"
  26. namespace mindspore::graphkernel {
  27. namespace {
  28. /**
  29. * If an Assign's source node was outputted with this Assign, the src-node should be removed from output list,
  30. * external users can use the dest-node under the premise of correct execution order.
  31. * This function find out the [index of src node in output list] and [external dest-node].
  32. * Note:
  33. * 1. Assign is always in output list. (links to external Depend node)
  34. * 2. Assign's dest-node should be a Parameter.
  35. */
  36. std::map<size_t, AnfNodePtr> FindAssignAndOutputVal(const CNodePtr &fg_cnode) {
  37. // Check output includes assign
  38. auto func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(fg_cnode);
  39. MS_EXCEPTION_IF_NULL(func_graph);
  40. auto out_cnode = func_graph->output()->cast<CNodePtr>();
  41. MS_EXCEPTION_IF_NULL(out_cnode);
  42. std::map<size_t, AnfNodePtr> output_replace_map;
  43. if (!IsPrimitiveCNode(out_cnode, prim::kPrimMakeTuple)) {
  44. return output_replace_map;
  45. }
  46. // Trans parameter to the real input
  47. auto ParameterToInput = [&func_graph, &fg_cnode](const AnfNodePtr &p) {
  48. const auto &params = func_graph->parameters();
  49. size_t i = std::find(params.begin(), params.end(), p) - params.begin();
  50. return i == params.size() ? nullptr : fg_cnode->input(i + 1);
  51. };
  52. const auto &inputs = out_cnode->inputs();
  53. for (const auto &out : inputs) {
  54. if (IsPrimitiveCNode(out, prim::kPrimAssign)) {
  55. auto assign_val = out->cast<CNodePtr>()->input(2);
  56. auto assign_parameter = out->cast<CNodePtr>()->input(1);
  57. auto iter = std::find(inputs.begin() + 1, inputs.end(), assign_val);
  58. if (iter != inputs.end()) {
  59. size_t assign_val_index = iter - inputs.begin();
  60. auto assign_to = ParameterToInput(assign_parameter);
  61. if (assign_to != nullptr && assign_val_index > 0) {
  62. output_replace_map[assign_val_index - 1] = assign_to;
  63. }
  64. }
  65. }
  66. }
  67. return output_replace_map;
  68. }
  69. bool HasPathToParamUser(const AnfNodePtr &gk_node, const AnfNodePtr &param_user, const AnfNodePtr &getitem) {
  70. auto mng = common::AnfAlgo::GetCNodeFuncGraphPtr(gk_node)->manager();
  71. MS_EXCEPTION_IF_NULL(mng);
  72. bool result = false;
  73. auto IncludeUser = [&result, &gk_node, &getitem](const AnfNodePtr &node) {
  74. if (node == getitem) {
  75. return EXCLUDE;
  76. }
  77. if (node == gk_node) {
  78. result = true;
  79. return EXCLUDE;
  80. }
  81. return result ? EXCLUDE : FOLLOW;
  82. };
  83. static_cast<void>(DeepLinkedGraphSearch(param_user, IncludeUser));
  84. return result;
  85. }
  86. void KeepExecOrder(const FuncGraphPtr &func_graph, const AnfNodePtr &getitem, const AnfNodePtr &assign_to_node,
  87. const FuncGraphManagerPtr &mng) {
  88. // Insert update_state_node, need mount a monad node.
  89. auto u = NewValueNode(kUMonad);
  90. u->set_abstract(kUMonad->ToAbstract());
  91. AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, getitem};
  92. auto update_state_node = func_graph->NewCNode(update_state_inputs);
  93. update_state_node->set_abstract(getitem->abstract());
  94. func_graph->AddNode(update_state_node);
  95. // Insert load_node
  96. AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), assign_to_node, update_state_node};
  97. auto load_node = func_graph->NewCNode(load_inputs);
  98. load_node->set_abstract(assign_to_node->abstract());
  99. func_graph->AddNode(load_node);
  100. (void)mng->Replace(getitem, load_node);
  101. }
  102. int64_t GetitemIndex(const AnfNodePtr &getitem) {
  103. auto index_node = getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem);
  104. auto value_ptr = GetValueNode(index_node);
  105. return GetValue<int64_t>(value_ptr);
  106. }
  107. void UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode, const AnfNodePtr &assign_to,
  108. int64_t removed_index) {
  109. auto mng = func_graph->manager();
  110. MS_EXCEPTION_IF_NULL(mng);
  111. for (const auto &getitem_iter : mng->node_users()[cnode]) {
  112. auto getitem = getitem_iter.first;
  113. if (GetitemIndex(getitem) != removed_index) continue;
  114. auto getitem_users = mng->node_users()[getitem]; // get a copy of getitem's users before replacing
  115. for (const auto &getitem_user_iter : getitem_users) {
  116. auto getitem_user = getitem_user_iter.first;
  117. // 1. Data users may not link directly to its input, they may segregated by Depend node.
  118. // 2. If the `cnode` has another path to the getitem_user, it's unnecessary to add update_state and load node to
  119. // keep exec_order.
  120. if (HasPathToParamUser(cnode, getitem_user, getitem)) {
  121. (void)mng->Replace(getitem, assign_to);
  122. continue;
  123. }
  124. KeepExecOrder(func_graph, getitem, assign_to, mng);
  125. }
  126. break;
  127. }
  128. }
  129. bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) {
  130. auto todos = TopoSort(func_graph->get_return());
  131. MS_EXCEPTION_IF_NULL(func_graph);
  132. bool changed = false;
  133. for (const auto &n : todos) {
  134. if (!common::AnfAlgo::IsGraphKernel(n)) continue;
  135. auto cnode = n->cast<CNodePtr>();
  136. auto replaceable_nodes = FindAssignAndOutputVal(cnode);
  137. if (replaceable_nodes.empty()) continue;
  138. changed = true;
  139. for (const auto &iter : replaceable_nodes) {
  140. UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, static_cast<int64_t>(iter.first));
  141. }
  142. }
  143. return changed;
  144. }
  145. bool ReplaceAssignByInplaceAssignInGraphkernel(const FuncGraphPtr &func_graph) {
  146. auto mng = func_graph->manager();
  147. MS_EXCEPTION_IF_NULL(mng);
  148. auto todos = TopoSort(func_graph->get_return());
  149. bool changed = false;
  150. for (const auto &n : todos) {
  151. if (!common::AnfAlgo::CheckPrimitiveType(n, prim::kPrimAssign)) continue;
  152. changed = true;
  153. auto cnode = n->cast<CNodePtr>();
  154. AnfNodePtrList inputs = {NewValueNode(prim::kPrimInplaceAssign), cnode->input(1), cnode->input(2), cnode->input(2)};
  155. auto new_cnode = func_graph->NewCNode(inputs);
  156. SetNodeAttrSafely("fake_output", MakeValue(true), new_cnode);
  157. new_cnode->set_abstract(inputs.back()->abstract());
  158. new_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  159. std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(cnode);
  160. std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(cnode);
  161. input_formats.push_back(input_formats.back());
  162. input_types.push_back(input_types.back());
  163. std::vector<std::string> output_formats = {input_formats.back()};
  164. std::vector<TypeId> output_types = {input_types.back()};
  165. auto graph_sel_info = BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types,
  166. AnfAlgo::GetProcessor(cnode));
  167. AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, new_cnode.get());
  168. (void)mng->Replace(cnode, new_cnode);
  169. }
  170. return changed;
  171. }
  172. bool RepalceAssignByInplaceAssign(const FuncGraphPtr &func_graph) {
  173. MS_EXCEPTION_IF_NULL(func_graph);
  174. auto mng = func_graph->manager();
  175. MS_EXCEPTION_IF_NULL(mng);
  176. auto todos = TopoSort(func_graph->get_return());
  177. auto changed = false;
  178. for (const auto &n : todos) {
  179. if (!common::AnfAlgo::IsGraphKernel(n)) continue;
  180. auto graph_kernel_fg = common::AnfAlgo::GetCNodeFuncGraphPtr(n);
  181. MS_EXCEPTION_IF_NULL(graph_kernel_fg);
  182. changed = ReplaceAssignByInplaceAssignInGraphkernel(graph_kernel_fg) || changed;
  183. }
  184. return changed;
  185. }
  186. } // namespace
  187. bool OptimizeAssign::Run(const FuncGraphPtr &func_graph) {
  188. auto mng = func_graph->manager();
  189. if (mng == nullptr) {
  190. mng = Manage(func_graph, true);
  191. func_graph->set_manager(mng);
  192. }
  193. auto res = RepalceOutputByParameter(func_graph);
  194. if (res) {
  195. mng->RemoveRoots();
  196. mng->KeepRoots({func_graph});
  197. }
  198. return RepalceAssignByInplaceAssign(func_graph);
  199. }
  200. } // namespace mindspore::graphkernel