You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimize_dependence.cc 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pre_activate/pass/optimize_dependence.h"
  17. #include <memory>
  18. #include <vector>
  19. #include <string>
  20. #include "pre_activate/common/helper.h"
  21. #include "operator/ops.h"
  22. #include "utils/utils.h"
  23. #include "session/kernel_graph.h"
  24. #include "session/anf_runtime_algorithm.h"
  25. namespace mindspore {
  26. namespace opt {
  // Position of the lone data input of a single-input CNode such as TransData or Cast.
  // GetReplaceNode expects such a node to carry exactly kSingleInputIndex + 1 inputs.
  constexpr int kSingleInputIndex{1};
  28. namespace {
  29. AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
  30. MS_EXCEPTION_IF_NULL(node);
  31. if (!node->isa<CNode>()) {
  32. return nullptr;
  33. }
  34. auto cnode = node->cast<CNodePtr>();
  35. MS_EXCEPTION_IF_NULL(cnode);
  36. string op_name = AnfAlgo::GetCNodeName(cnode);
  37. // Currently we only eliminate transdata or cast nodes.
  38. if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
  39. return nullptr;
  40. }
  41. CheckCNodeInputSize(cnode, kSingleInputIndex + 1);
  42. return cnode->input(kSingleInputIndex);
  43. }
  44. bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  45. MS_EXCEPTION_IF_NULL(func_graph);
  46. MS_EXCEPTION_IF_NULL(cnode);
  47. if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
  48. return false;
  49. }
  50. std::vector<AnfNodePtr> new_make_tuple_inputs;
  51. bool need_update = false;
  52. for (const auto &input : cnode->inputs()) {
  53. AnfNodePtr replace_input = GetReplaceNode(input);
  54. // If replace input is not null, it will be the input of the TransData or Cast.
  55. if (replace_input == nullptr) {
  56. new_make_tuple_inputs.push_back(input);
  57. continue;
  58. }
  59. new_make_tuple_inputs.push_back(replace_input);
  60. need_update = true;
  61. }
  62. if (need_update) {
  63. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  64. CNodePtr new_make_tuple = nullptr;
  65. if (kernel_graph == nullptr) {
  66. new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs);
  67. } else {
  68. new_make_tuple = kernel_graph->NewCNode(cnode);
  69. }
  70. MS_EXCEPTION_IF_NULL(new_make_tuple);
  71. new_make_tuple->set_inputs(new_make_tuple_inputs);
  72. auto manager = func_graph->manager();
  73. MS_EXCEPTION_IF_NULL(manager);
  74. manager->Replace(cnode, new_make_tuple);
  75. }
  76. return true;
  77. }
  78. } // namespace
  79. const BaseRef OptimizeDependence::DefinePattern() const {
  80. VarPtr X = std::make_shared<Var>("X");
  81. MS_EXCEPTION_IF_NULL(X);
  82. VarPtr Y = std::make_shared<Var>("Y");
  83. MS_EXCEPTION_IF_NULL(Y);
  84. return VectorRef({prim::kPrimDepend, X, Y});
  85. }
  86. const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  87. const EquivPtr &) const {
  88. MS_EXCEPTION_IF_NULL(func_graph);
  89. MS_EXCEPTION_IF_NULL(node);
  90. if (!node->isa<CNode>()) {
  91. return nullptr;
  92. }
  93. auto depend_cnode = node->cast<CNodePtr>();
  94. MS_EXCEPTION_IF_NULL(depend_cnode);
  95. CheckCNodeInputSize(depend_cnode, kDependInputNum);
  96. auto replacing_node = depend_cnode->input(kDependInputNum - 1);
  97. MS_EXCEPTION_IF_NULL(replacing_node);
  98. if (!replacing_node->isa<CNode>()) {
  99. return nullptr;
  100. }
  101. auto replacing_cnode = replacing_node->cast<CNodePtr>();
  102. MS_EXCEPTION_IF_NULL(replacing_cnode);
  103. // Deal with the make_tuple with TransData or Cast inputs.
  104. if (ReplaceMakeTuple(func_graph, replacing_cnode)) {
  105. return nullptr;
  106. }
  107. AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
  108. if (replace_node == nullptr) {
  109. MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
  110. return nullptr;
  111. }
  112. std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex),
  113. depend_cnode->input(kRealInputIndexInDepend), replace_node};
  114. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  115. CNodePtr new_depend;
  116. if (kernel_graph == nullptr) {
  117. new_depend = func_graph->NewCNode(new_depend_inputs);
  118. MS_EXCEPTION_IF_NULL(new_depend);
  119. new_depend->set_abstract(node->abstract());
  120. new_depend->set_scope(node->scope());
  121. } else {
  122. new_depend = kernel_graph->NewCNode(depend_cnode);
  123. MS_EXCEPTION_IF_NULL(new_depend);
  124. new_depend->set_inputs(new_depend_inputs);
  125. }
  126. return new_depend;
  127. }
  128. } // namespace opt
  129. } // namespace mindspore