You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

optimize_dependence.cc 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/pass/optimize_dependence.h"
  17. #include <memory>
  18. #include <vector>
  19. #include <string>
  20. #include "backend/optimizer/common/helper.h"
  21. #include "base/core_ops.h"
  22. #include "utils/utils.h"
  23. #include "backend/session/kernel_graph.h"
  24. #include "backend/session/anf_runtime_algorithm.h"
  25. namespace mindspore {
  26. namespace opt {
  27. constexpr auto kSingleInputIndex = 1;  // Input slot GetReplaceNode reads from a TransData/Cast CNode; also the count it passes to CheckCNodeInputSize.
  28. namespace {
  29. AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  30. MS_EXCEPTION_IF_NULL(node);
  31. if (!node->isa<CNode>()) {
  32. return nullptr;
  33. }
  34. auto cnode = node->cast<CNodePtr>();
  35. MS_EXCEPTION_IF_NULL(cnode);
  36. string op_name = AnfAlgo::GetCNodeName(cnode);
  37. // Currently we only eliminate transdata or cast nodes.
  38. if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
  39. return nullptr;
  40. }
  41. if (!IsNotRealUsedByOthers(func_graph, cnode)) {
  42. return nullptr;
  43. }
  44. CheckCNodeInputSize(cnode, kSingleInputIndex);
  45. return cnode->input(kSingleInputIndex);
  46. }
  47. AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  48. MS_EXCEPTION_IF_NULL(func_graph);
  49. MS_EXCEPTION_IF_NULL(cnode);
  50. if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
  51. return nullptr;
  52. }
  53. std::vector<AnfNodePtr> new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  54. bool need_update = false;
  55. size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  56. for (size_t index = 0; index < input_num; ++index) {
  57. auto input = AnfAlgo::GetInputNode(cnode, index);
  58. AnfNodePtr replace_input = GetReplaceNode(func_graph, input);
  59. // If replace input is not null, it will be the input of the TransData or Cast.
  60. if (replace_input == nullptr) {
  61. new_make_tuple_inputs.push_back(input);
  62. continue;
  63. }
  64. new_make_tuple_inputs.push_back(replace_input);
  65. need_update = true;
  66. }
  67. if (need_update) {
  68. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  69. CNodePtr new_make_tuple = nullptr;
  70. if (kernel_graph == nullptr) {
  71. new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs);
  72. } else {
  73. new_make_tuple = kernel_graph->NewCNode(cnode);
  74. }
  75. MS_EXCEPTION_IF_NULL(new_make_tuple);
  76. new_make_tuple->set_inputs(new_make_tuple_inputs);
  77. auto manager = func_graph->manager();
  78. MS_EXCEPTION_IF_NULL(manager);
  79. manager->Replace(cnode, new_make_tuple);
  80. return new_make_tuple;
  81. }
  82. return nullptr;
  83. }
  84. } // namespace
  85. const BaseRef OptimizeDependence::DefinePattern() const {
  86. VarPtr X = std::make_shared<Var>();
  87. VarPtr Xs = std::make_shared<SeqVar>();
  88. return VectorRef({X, Xs});
  89. }
  90. std::pair<AnfNodePtr, size_t> SearchTransDataAndCast(const AnfNodePtr &node, bool is_first_node) {
  91. if (node == nullptr || !node->isa<CNode>()) {
  92. return std::pair<AnfNodePtr, size_t>(nullptr, 0);
  93. }
  94. // get real input of depend and update state.
  95. size_t replace_input_index = 0;
  96. if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
  97. replace_input_index = is_first_node ? kDependAttachNodeIndex : kRealInputIndexInDepend;
  98. } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
  99. replace_input_index = is_first_node ? kUpdateStateStateInput : kUpdateStateRealInput;
  100. } else {
  101. return std::pair<AnfNodePtr, size_t>(nullptr, 0);
  102. }
  103. // check whether real input is cast or trans data
  104. auto real_input = node->cast<CNodePtr>()->input(replace_input_index);
  105. if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimCast) ||
  106. AnfAlgo::CheckPrimitiveType(real_input, prim::KPrimTransData) ||
  107. AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimMakeTuple)) {
  108. return std::pair<AnfNodePtr, size_t>(node, replace_input_index);
  109. }
  110. return SearchTransDataAndCast(real_input, false);
  111. }
  112. const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  113. const EquivPtr &) const {
  114. MS_EXCEPTION_IF_NULL(func_graph);
  115. MS_EXCEPTION_IF_NULL(node);
  116. if (!node->isa<CNode>()) {
  117. return nullptr;
  118. }
  119. // Get the cnode with repalce input index
  120. auto cnode_with_input_index = SearchTransDataAndCast(node, true);
  121. if (cnode_with_input_index.first == nullptr) {
  122. return nullptr;
  123. }
  124. size_t replace_index = cnode_with_input_index.second;
  125. auto depend_cnode = cnode_with_input_index.first->cast<CNodePtr>();
  126. MS_EXCEPTION_IF_NULL(depend_cnode);
  127. // Get new node which will act as new input of depend or UpdateState.
  128. std::vector<AnfNodePtr> new_depend_inputs = depend_cnode->inputs();
  129. auto replace_node = GetConvertNode(func_graph, depend_cnode, replace_index);
  130. if (replace_node == nullptr) {
  131. return nullptr;
  132. }
  133. new_depend_inputs[replace_index] = replace_node;
  134. // Because depend's input has been changed, so a new depend(UpdateState) node will be created to replaced the old one.
  135. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  136. CNodePtr new_depend = nullptr;
  137. if (kernel_graph == nullptr) {
  138. new_depend = func_graph->NewCNode(new_depend_inputs);
  139. MS_EXCEPTION_IF_NULL(new_depend);
  140. new_depend->set_abstract(depend_cnode->abstract());
  141. new_depend->set_scope(depend_cnode->scope());
  142. } else {
  143. new_depend = kernel_graph->NewCNode(depend_cnode);
  144. MS_EXCEPTION_IF_NULL(new_depend);
  145. new_depend->set_inputs(new_depend_inputs);
  146. }
  147. func_graph->manager()->Replace(depend_cnode, new_depend);
  148. return nullptr;
  149. }
  150. const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
  151. const size_t index) const {
  152. MS_EXCEPTION_IF_NULL(graph);
  153. MS_EXCEPTION_IF_NULL(node);
  154. auto depend_cnode = node->cast<CNodePtr>();
  155. auto replacing_node = depend_cnode->input(index);
  156. MS_EXCEPTION_IF_NULL(replacing_node);
  157. if (!replacing_node->isa<CNode>()) {
  158. return nullptr;
  159. }
  160. auto replacing_cnode = replacing_node->cast<CNodePtr>();
  161. MS_EXCEPTION_IF_NULL(replacing_cnode);
  162. // Deal with the make_tuple with TransData or Cast inputs.
  163. auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode);
  164. if (make_tuple_replace_node != nullptr) {
  165. return make_tuple_replace_node;
  166. }
  167. AnfNodePtr replace_node = GetReplaceNode(graph, replacing_cnode);
  168. return replace_node;
  169. }
  170. } // namespace opt
  171. } // namespace mindspore