You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

while_pass.cc 5.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/optimizer/graph/while_pass.h"
  17. #include <vector>
  18. #include <memory>
  19. #include "ops/switch.h"
  20. #include "include/errorcode.h"
  21. #include "tools/optimizer/common/gllo_utils.h"
  22. #include "src/common/log_adapter.h"
  23. namespace mindspore::opt {
  24. ValueNodePtr WhilePass::GetSwitchAnfPrim() {
  25. auto switch_prim = std::make_shared<mindspore::ops::Switch>();
  26. ValueNodePtr partial_anf_prim = NewValueNode(switch_prim);
  27. return partial_anf_prim;
  28. }
  29. bool WhilePass::Run(const FuncGraphPtr &graph) {
  30. auto node_list = TopoSort(graph->get_return());
  31. static int count = 0;
  32. for (auto &node : node_list) {
  33. if (!utils::isa<CNodePtr>(node)) {
  34. continue;
  35. }
  36. if (!CheckPrimitiveType(node, prim::kPrimWhile)) {
  37. continue;
  38. }
  39. auto while_cnode = node->cast<CNodePtr>();
  40. MS_ASSERT(while_cnode != nullptr);
  41. if (while_cnode->inputs().size() < kWhileMinInputSize) {
  42. MS_LOG(ERROR) << "while input is not right.";
  43. return false;
  44. }
  45. // the order is fixed.
  46. auto cond_vnode = while_cnode->input(kWhileCondIndex);
  47. auto body_vnode = while_cnode->input(kWhileBodyIndex);
  48. auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
  49. auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
  50. if (cond_fg == nullptr || body_fg == nullptr) {
  51. MS_LOG(ERROR) << "Get value as func_graph failed.";
  52. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
  53. return false;
  54. }
  55. std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode};
  56. std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode};
  57. cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
  58. while_cnode->inputs().end());
  59. body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
  60. while_cnode->inputs().end());
  61. static int idx = 0;
  62. auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs);
  63. cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx));
  64. cond_partial_node->set_abstract(cond_fg->output()->abstract());
  65. auto body_partial_node = graph->NewCNode(body_partial_op_inputs);
  66. body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx));
  67. idx++;
  68. // concat body_fg output to cond_fg input
  69. auto body_output = body_fg->output();
  70. auto body_output_cnode = utils::cast<CNodePtr>(body_output);
  71. auto prim = GetValueNode<PrimitiveCPtr>(body_output_cnode->input(0));
  72. if (prim == nullptr) {
  73. MS_LOG(ERROR) << "Get PrimitiveC of node:" << body_output_cnode->fullname_with_scope() << " failed.";
  74. return false;
  75. }
  76. // concat body to cond
  77. std::vector<AnfNodePtr> body_to_cond_inputs{cond_vnode};
  78. if (CheckPrimitiveType(body_output_cnode, kPrimMakeTuple)) {
  79. for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) {
  80. body_to_cond_inputs.emplace_back(body_output_cnode->input(i));
  81. }
  82. } else {
  83. body_to_cond_inputs.emplace_back(body_output_cnode);
  84. }
  85. // concat body to cond
  86. auto body_to_cond_cnode = body_fg->NewCNode(body_to_cond_inputs);
  87. body_to_cond_cnode->set_fullname_with_scope("Partial-while-body-to-cond");
  88. auto body_fg_manager = body_fg->manager();
  89. body_fg_manager->Replace(body_fg->output(), body_to_cond_cnode);
  90. body_fg->set_output(body_to_cond_cnode);
  91. body_partial_node->set_abstract(cond_fg->output()->abstract());
  92. // create switch cnode
  93. ValueNodePtr switch_anf_primitive = GetSwitchAnfPrim();
  94. if (switch_anf_primitive == nullptr) {
  95. MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
  96. return false;
  97. }
  98. // insert switch node
  99. std::vector<AnfNodePtr> switch_op_inputs = {switch_anf_primitive, cond_partial_node, body_partial_node};
  100. auto switch_cnode = graph->NewCNode(switch_op_inputs);
  101. switch_cnode->set_fullname_with_scope("Switch-" + std::to_string(count++));
  102. AbstractBasePtrList abstract_list;
  103. auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output());
  104. for (auto &cnode : body_fg_output_cnode->inputs()) {
  105. if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) {
  106. continue;
  107. }
  108. abstract_list.push_back(cnode->abstract());
  109. }
  110. switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  111. // create cond partial cnode
  112. auto manager = graph->manager();
  113. if (!manager->Replace(while_cnode, switch_cnode)) {
  114. MS_LOG(ERROR) << "replace node failed.";
  115. return false;
  116. }
  117. }
  118. return true;
  119. }
  120. } // namespace mindspore::opt