You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimize_dependence.cc 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/pass/optimize_dependence.h"
  17. #include <memory>
  18. #include <vector>
  19. #include <string>
  20. #include "backend/optimizer/common/helper.h"
  21. #include "base/core_ops.h"
  22. #include "utils/utils.h"
  23. #include "backend/session/kernel_graph.h"
  24. #include "backend/session/anf_runtime_algorithm.h"
  25. namespace mindspore {
  26. namespace opt {
  // Position of the one data input of a TransData/Cast CNode (slot 0 holds the primitive).
  constexpr int kSingleInputIndex{1};
  28. namespace {
  29. AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  30. MS_EXCEPTION_IF_NULL(node);
  31. if (!node->isa<CNode>()) {
  32. return nullptr;
  33. }
  34. auto cnode = node->cast<CNodePtr>();
  35. MS_EXCEPTION_IF_NULL(cnode);
  36. string op_name = AnfAlgo::GetCNodeName(cnode);
  37. // Currently we only eliminate transdata or cast nodes.
  38. if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
  39. return nullptr;
  40. }
  41. if (!IsNotRealUsedByOthers(func_graph, cnode)) {
  42. return nullptr;
  43. }
  44. CheckCNodeInputSize(cnode, kSingleInputIndex + 1);
  45. return cnode->input(kSingleInputIndex);
  46. }
  47. AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  48. MS_EXCEPTION_IF_NULL(func_graph);
  49. MS_EXCEPTION_IF_NULL(cnode);
  50. if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
  51. return nullptr;
  52. }
  53. std::vector<AnfNodePtr> new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  54. bool need_update = false;
  55. for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
  56. auto input = AnfAlgo::GetInputNode(cnode, index);
  57. AnfNodePtr replace_input = GetReplaceNode(func_graph, input);
  58. // If replace input is not null, it will be the input of the TransData or Cast.
  59. if (replace_input == nullptr) {
  60. new_make_tuple_inputs.push_back(input);
  61. continue;
  62. }
  63. new_make_tuple_inputs.push_back(replace_input);
  64. need_update = true;
  65. }
  66. if (need_update) {
  67. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  68. CNodePtr new_make_tuple = nullptr;
  69. if (kernel_graph == nullptr) {
  70. new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs);
  71. } else {
  72. new_make_tuple = kernel_graph->NewCNode(cnode);
  73. }
  74. MS_EXCEPTION_IF_NULL(new_make_tuple);
  75. new_make_tuple->set_inputs(new_make_tuple_inputs);
  76. auto manager = func_graph->manager();
  77. MS_EXCEPTION_IF_NULL(manager);
  78. manager->Replace(cnode, new_make_tuple);
  79. return new_make_tuple;
  80. }
  81. return nullptr;
  82. }
  83. } // namespace
  84. const BaseRef OptimizeDependence::DefinePattern() const {
  85. VarPtr X = std::make_shared<Var>();
  86. VarPtr Xs = std::make_shared<SeqVar>();
  87. return VectorRef({X, Xs});
  88. }
  89. const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  90. const EquivPtr &) const {
  91. MS_EXCEPTION_IF_NULL(func_graph);
  92. MS_EXCEPTION_IF_NULL(node);
  93. if (!node->isa<CNode>()) {
  94. return nullptr;
  95. }
  96. auto node_name = AnfAlgo::GetCNodeName(node);
  97. if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) {
  98. return nullptr;
  99. }
  100. size_t index = 0;
  101. auto depend_cnode = node->cast<CNodePtr>();
  102. MS_EXCEPTION_IF_NULL(depend_cnode);
  103. std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)};
  104. if (node_name == prim::kPrimDepend->name()) {
  105. index = 1;
  106. new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend));
  107. }
  108. if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) {
  109. MS_LOG(EXCEPTION) << "The depend node input size is at less size 2,but got "
  110. << AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString();
  111. }
  112. auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode);
  113. while (index < input_num) {
  114. auto replace_node = GetConvertNode(func_graph, node, index);
  115. MS_EXCEPTION_IF_NULL(replace_node);
  116. new_depend_inputs.push_back(replace_node);
  117. ++index;
  118. }
  119. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  120. CNodePtr new_depend = nullptr;
  121. if (kernel_graph == nullptr) {
  122. new_depend = func_graph->NewCNode(new_depend_inputs);
  123. MS_EXCEPTION_IF_NULL(new_depend);
  124. new_depend->set_abstract(node->abstract());
  125. new_depend->set_scope(node->scope());
  126. } else {
  127. new_depend = kernel_graph->NewCNode(depend_cnode);
  128. MS_EXCEPTION_IF_NULL(new_depend);
  129. new_depend->set_inputs(new_depend_inputs);
  130. }
  131. return new_depend;
  132. }
  133. const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
  134. const size_t index) const {
  135. MS_EXCEPTION_IF_NULL(graph);
  136. MS_EXCEPTION_IF_NULL(node);
  137. auto depend_cnode = node->cast<CNodePtr>();
  138. auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
  139. MS_EXCEPTION_IF_NULL(replacing_node);
  140. if (!replacing_node->isa<CNode>()) {
  141. return replacing_node;
  142. }
  143. auto replacing_cnode = replacing_node->cast<CNodePtr>();
  144. MS_EXCEPTION_IF_NULL(replacing_cnode);
  145. // Deal with the make_tuple with TransData or Cast inputs.
  146. auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode);
  147. if (make_tuple_replace_node != nullptr) {
  148. return make_tuple_replace_node;
  149. }
  150. AnfNodePtr replace_node = GetReplaceNode(graph, replacing_cnode);
  151. if (replace_node == nullptr) {
  152. MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
  153. return replacing_node;
  154. }
  155. return replace_node;
  156. }
  157. } // namespace opt
  158. } // namespace mindspore