You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functionalize_cond.cc 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/optimizer/graph/functionalize_cond.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <deque>
  20. #include <unordered_set>
  21. #include "include/errorcode.h"
  22. #include "tools/converter/ops/ops_def.h"
  23. namespace mindspore::opt {
  24. STATUS FunctionalizeCond::GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type) {
  25. MS_ASSERT(switch_cnode != nullptr);
  26. MS_ASSERT(branch_type != nullptr);
  27. auto manager = fg_->manager();
  28. if (manager == nullptr) {
  29. MS_LOG(ERROR) << "manager is nullptr";
  30. return RET_ERROR;
  31. }
  32. auto node_users = manager->node_users()[switch_cnode];
  33. if (node_users.size() != 1) { // only one output of switch is referenced in cond
  34. MS_LOG(ERROR) << "switch's node users is not correct";
  35. return RET_ERROR;
  36. }
  37. auto node_user = node_users.front();
  38. auto tuple_get_item = node_user.first;
  39. if (!utils::isa<CNodePtr>(tuple_get_item) || !CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
  40. MS_LOG(ERROR) << "switch's node user is not TupleGetItem";
  41. return RET_ERROR;
  42. }
  43. auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item);
  44. auto idx = GetTupleGetItemOutIndex(tuple_get_item_cnode);
  45. if (idx == 0) {
  46. *branch_type = kElseBranch;
  47. } else if (idx == 1) {
  48. *branch_type = kThenBranch;
  49. } else {
  50. MS_LOG(ERROR) << "wrong tuple_get_item index";
  51. return RET_ERROR;
  52. }
  53. return RET_OK;
  54. }
  55. STATUS FunctionalizeCond::BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node,
  56. BranchType branch_type) {
  57. std::deque<AnfNodePtr> q;
  58. std::unordered_set<AnfNodePtr> vis;
  59. q.push_back(root_node);
  60. while (!q.empty()) {
  61. auto node = q.front();
  62. q.pop_front();
  63. vis.insert(node);
  64. if (FunctionalizeControlOpPass::IsSwitch(node)) {
  65. auto cnode = utils::cast<CNodePtr>(node);
  66. BranchType this_type;
  67. if (GetSwitchBranchType(cnode, &this_type) != RET_OK || this_type != branch_type) {
  68. MS_LOG(ERROR) << "switch node in branch " << branch_type << " is not correct";
  69. return RET_ERROR;
  70. }
  71. continue;
  72. }
  73. if (utils::isa<ParameterPtr>(node)) {
  74. graph->add_parameter(node->cast<ParameterPtr>());
  75. } else {
  76. graph->AddNode(node);
  77. }
  78. node->set_func_graph(graph);
  79. if (utils::isa<CNodePtr>(node)) {
  80. auto cnode = utils::cast<CNodePtr>(node);
  81. for (size_t i = 1; i < cnode->inputs().size(); i++) {
  82. auto inputi = cnode->input(i);
  83. if (vis.find(inputi) == vis.end()) {
  84. q.push_back(cnode->input(i));
  85. }
  86. }
  87. }
  88. }
  89. return RET_OK;
  90. }
  91. int FunctionalizeCond::PosInInputNodes(const CNodePtr &node) {
  92. auto index = std::find(input_nodes_.begin(), input_nodes_.end(), node);
  93. if (index == input_nodes_.end()) {
  94. input_nodes_.push_back(node);
  95. return input_nodes_.size() - 1;
  96. }
  97. return index - input_nodes_.begin();
  98. }
  99. STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name) {
  100. std::vector<AnfNodePtr> nodes_need_drop{};
  101. for (auto &cnode : graph->GetOrderedCnodes()) {
  102. for (auto &input_node : cnode->inputs()) {
  103. if (FunctionalizeControlOpPass::IsSwitch(input_node)) {
  104. auto switch_node = input_node->cast<CNodePtr>();
  105. auto switch_input = utils::cast<CNodePtr>(switch_node->input(1));
  106. auto pos = PosInInputNodes(switch_input);
  107. nodes_need_drop.push_back(cnode);
  108. pred_nodes_.push_back(switch_node->input(2));
  109. // set parameter
  110. auto parameter = graph->add_parameter();
  111. parameter->set_abstract(cnode->abstract());
  112. // hardcode for subgraph input name
  113. parameter->set_name(graph_name + "_input_" + std::to_string(pos) + "_parameter");
  114. // replace switch
  115. auto manager = fg_->manager();
  116. auto node_users = manager->node_users()[cnode];
  117. for (auto &node_user : node_users) {
  118. if (graph->nodes().contains(node_user.first)) {
  119. manager->SetEdge(node_user.first, node_user.second, parameter);
  120. }
  121. }
  122. }
  123. }
  124. }
  125. return RET_OK;
  126. }
  127. FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
  128. auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, mindspore::lite::converter::FmkType_TF);
  129. if (graph == nullptr) {
  130. MS_LOG(ERROR) << "new graph Partial Node return nullptr";
  131. return nullptr;
  132. }
  133. graph->set_manager(fg_->manager());
  134. auto status = BranchSubGraphAddNodes(graph, node, branch_type);
  135. if (status != RET_OK) {
  136. return nullptr;
  137. }
  138. if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty
  139. auto return_prim_ptr = std::make_shared<lite::Return>();
  140. if (return_prim_ptr == nullptr) {
  141. MS_LOG(ERROR) << "GetReturnPrim return nullptr";
  142. return nullptr;
  143. }
  144. auto value_node = NewValueNode(return_prim_ptr);
  145. std::vector<AnfNodePtr> op_inputs{value_node, node}; // If subgraph only has one output tensor
  146. auto return_cnode = graph->NewCNode(op_inputs);
  147. return_cnode->set_fullname_with_scope(name + "-return");
  148. return_cnode->set_func_graph(graph);
  149. graph->set_return(return_cnode);
  150. graph->output()->cast<CNodePtr>()->set_fullname_with_scope(name + "_output_0_cnode");
  151. }
  152. return graph;
  153. }
  154. CNodePtr FunctionalizeCond::CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch) {
  155. MS_ASSERT(else_branch != nullptr);
  156. MS_ASSERT(then_branch != nullptr);
  157. auto if_primc = std::make_shared<mindspore::lite::If>();
  158. if (if_primc == nullptr) {
  159. MS_LOG(ERROR) << "new if_primitive failed";
  160. return nullptr;
  161. }
  162. auto if_value_node = NewValueNode(if_primc);
  163. if (if_value_node == nullptr) {
  164. return nullptr;
  165. }
  166. auto then_value_node = NewValueNode(then_branch);
  167. auto else_value_node = NewValueNode(else_branch);
  168. std::vector<AnfNodePtr> if_op_inputs = {if_value_node, then_value_node, else_value_node, pred_node_};
  169. std::copy(input_nodes_.begin(), input_nodes_.end(), std::back_inserter(if_op_inputs));
  170. return fg_->NewCNode(if_op_inputs);
  171. }
  172. STATUS FunctionalizeCond::VerifyPredictNode() {
  173. if (pred_nodes_.empty()) {
  174. return RET_ERROR;
  175. }
  176. for (size_t i = 1; i < pred_nodes_.size(); ++i) {
  177. if (pred_nodes_[i] != pred_nodes_[0]) {
  178. return RET_ERROR;
  179. }
  180. }
  181. if (!utils::isa<CNodePtr>(pred_nodes_[0])) {
  182. return RET_ERROR;
  183. }
  184. pred_node_ = utils::cast<CNodePtr>(pred_nodes_[0]);
  185. return RET_OK;
  186. }
  187. STATUS FunctionalizeCond::Process() {
  188. if (fg_ == nullptr || merge_node_ == nullptr || merge_node_->inputs().size() != 3) {
  189. MS_LOG(ERROR) << "fg or merge is not correct";
  190. return RET_ERROR;
  191. }
  192. auto else_branch_name = merge_node_->fullname_with_scope() + "-partial-if-else";
  193. auto then_branch_name = merge_node_->fullname_with_scope() + "-partial-then-else";
  194. auto else_branch = CreateBranchGraph(merge_node_->input(1), else_branch_name, kElseBranch);
  195. if (else_branch == nullptr) {
  196. MS_LOG(ERROR) << "create else branch failed";
  197. return RET_ERROR;
  198. }
  199. auto then_branch = CreateBranchGraph(merge_node_->input(2), then_branch_name, kThenBranch);
  200. if (then_branch == nullptr) {
  201. MS_LOG(ERROR) << "create then branch failed";
  202. return RET_ERROR;
  203. }
  204. auto status = IdentifySubgraphInput(else_branch, else_branch_name);
  205. if (status != RET_OK) {
  206. return status;
  207. }
  208. status = IdentifySubgraphInput(then_branch, then_branch_name);
  209. if (status != RET_OK) {
  210. return status;
  211. }
  212. status = VerifyPredictNode();
  213. if (status != RET_OK) {
  214. return status;
  215. }
  216. auto if_node = CreateNewIf(else_branch, then_branch);
  217. if (if_node == nullptr) {
  218. MS_LOG(ERROR) << "create if node error";
  219. return RET_ERROR;
  220. }
  221. if_node->set_abstract(merge_node_->abstract()->Clone());
  222. auto manager = fg_->manager();
  223. auto node_users = manager->node_users()[merge_node_];
  224. for (auto &node_user : node_users) {
  225. manager->SetEdge(node_user.first, node_user.second, if_node);
  226. }
  227. return RET_OK;
  228. }
  229. } // namespace mindspore::opt