You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimize_dependence.cc 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/pass/optimize_dependence.h"
  17. #include <memory>
  18. #include <vector>
  19. #include <string>
  20. #include "backend/optimizer/common/helper.h"
  21. #include "frontend/operator/ops.h"
  22. #include "utils/utils.h"
  23. #include "backend/session/kernel_graph.h"
  24. #include "backend/session/anf_runtime_algorithm.h"
  25. namespace mindspore {
  26. namespace opt {
  // Position, within a TransData/Cast CNode's input list, of the input that GetReplaceNode()
  // returns as the replacement. GetReplaceNode() also requires the node to have
  // kSingleInputIndex + 1 inputs (presumably the primitive plus one data input - TODO confirm).
  27. constexpr auto kSingleInputIndex = 1;
  28. namespace {
  29. AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
  30. MS_EXCEPTION_IF_NULL(node);
  31. if (!node->isa<CNode>()) {
  32. return nullptr;
  33. }
  34. auto cnode = node->cast<CNodePtr>();
  35. MS_EXCEPTION_IF_NULL(cnode);
  36. string op_name = AnfAlgo::GetCNodeName(cnode);
  37. // Currently we only eliminate transdata or cast nodes.
  38. if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
  39. return nullptr;
  40. }
  41. CheckCNodeInputSize(cnode, kSingleInputIndex + 1);
  42. return cnode->input(kSingleInputIndex);
  43. }
  44. AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  45. MS_EXCEPTION_IF_NULL(func_graph);
  46. MS_EXCEPTION_IF_NULL(cnode);
  47. if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
  48. return nullptr;
  49. }
  50. std::vector<AnfNodePtr> new_make_tuple_inputs;
  51. bool need_update = false;
  52. for (const auto &input : cnode->inputs()) {
  53. AnfNodePtr replace_input = GetReplaceNode(input);
  54. // If replace input is not null, it will be the input of the TransData or Cast.
  55. if (replace_input == nullptr) {
  56. new_make_tuple_inputs.push_back(input);
  57. continue;
  58. }
  59. new_make_tuple_inputs.push_back(replace_input);
  60. need_update = true;
  61. }
  62. if (need_update) {
  63. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  64. CNodePtr new_make_tuple = nullptr;
  65. if (kernel_graph == nullptr) {
  66. new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs);
  67. } else {
  68. new_make_tuple = kernel_graph->NewCNode(cnode);
  69. }
  70. MS_EXCEPTION_IF_NULL(new_make_tuple);
  71. new_make_tuple->set_inputs(new_make_tuple_inputs);
  72. auto manager = func_graph->manager();
  73. MS_EXCEPTION_IF_NULL(manager);
  74. manager->Replace(cnode, new_make_tuple);
  75. return new_make_tuple;
  76. }
  77. return nullptr;
  78. }
  79. } // namespace
  80. const BaseRef OptimizeDependence::DefinePattern() const {
  81. VarPtr X = std::make_shared<Var>();
  82. VarPtr Xs = std::make_shared<SeqVar>();
  83. return VectorRef({X, Xs});
  84. }
  85. const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  86. const EquivPtr &) const {
  87. MS_EXCEPTION_IF_NULL(func_graph);
  88. MS_EXCEPTION_IF_NULL(node);
  89. if (!node->isa<CNode>()) {
  90. return nullptr;
  91. }
  92. auto node_name = AnfAlgo::GetCNodeName(node);
  93. if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) {
  94. return nullptr;
  95. }
  96. size_t index = 0;
  97. auto depend_cnode = node->cast<CNodePtr>();
  98. MS_EXCEPTION_IF_NULL(depend_cnode);
  99. std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)};
  100. if (node_name == prim::kPrimDepend->name()) {
  101. index = 1;
  102. new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend));
  103. }
  104. if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) {
  105. MS_LOG(EXCEPTION) << "The depend node input size is at less size 2,but got "
  106. << AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString();
  107. }
  108. auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode);
  109. while (index < input_num) {
  110. auto replace_node = GetConvertNode(func_graph, node, index);
  111. MS_EXCEPTION_IF_NULL(replace_node);
  112. new_depend_inputs.push_back(replace_node);
  113. ++index;
  114. }
  115. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  116. CNodePtr new_depend = nullptr;
  117. if (kernel_graph == nullptr) {
  118. new_depend = func_graph->NewCNode(new_depend_inputs);
  119. MS_EXCEPTION_IF_NULL(new_depend);
  120. new_depend->set_abstract(node->abstract());
  121. new_depend->set_scope(node->scope());
  122. } else {
  123. new_depend = kernel_graph->NewCNode(depend_cnode);
  124. MS_EXCEPTION_IF_NULL(new_depend);
  125. new_depend->set_inputs(new_depend_inputs);
  126. }
  127. return new_depend;
  128. }
  129. const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
  130. const size_t index) const {
  131. MS_EXCEPTION_IF_NULL(graph);
  132. MS_EXCEPTION_IF_NULL(node);
  133. auto depend_cnode = node->cast<CNodePtr>();
  134. auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
  135. MS_EXCEPTION_IF_NULL(replacing_node);
  136. if (!replacing_node->isa<CNode>()) {
  137. return replacing_node;
  138. }
  139. auto replacing_cnode = replacing_node->cast<CNodePtr>();
  140. MS_EXCEPTION_IF_NULL(replacing_cnode);
  141. // Deal with the make_tuple with TransData or Cast inputs.
  142. auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode);
  143. if (make_tuple_replace_node != nullptr) {
  144. return make_tuple_replace_node;
  145. }
  146. AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
  147. if (replace_node == nullptr) {
  148. MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
  149. return replacing_node;
  150. }
  151. return replace_node;
  152. }
  153. } // namespace opt
  154. } // namespace mindspore