diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc index 3e5c671ffd..527d7cfe0c 100644 --- a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc @@ -48,6 +48,9 @@ std::string FunctionalizeControlOpPass::NodeClusterName(const AnfNodePtr &node) void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) { for (auto &node : func_graph->nodes()) { + if (!utils::isa(node)) { + continue; + } auto cluster_name = NodeClusterName(node); auto cluster_pos = WhichCluster(cluster_name); if (cluster_pos == node_clusters_.size()) { @@ -90,6 +93,7 @@ STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_g MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret; return ret; } + break; } } } diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_while.cc b/mindspore/lite/tools/optimizer/graph/functionalize_while.cc index 29c5cf39e4..43fd259195 100644 --- a/mindspore/lite/tools/optimizer/graph/functionalize_while.cc +++ b/mindspore/lite/tools/optimizer/graph/functionalize_while.cc @@ -196,17 +196,14 @@ STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() { } STATUS FunctionalizeWhile::UpdateExitNodeUser() { + auto manager = fg_->manager(); if (output_exit_nodes_.size() == 1) { - auto manager = fg_->manager(); - auto node_users = manager->node_users()[output_exit_nodes_[0]]; - for (auto &node_user : node_users) { - if (fg_->nodes().contains(node_user.first)) { - manager->SetEdge(node_user.first, node_user.second, while_node_); - } + if (!manager->Replace(output_exit_nodes_[0], while_node_)) { + MS_LOG(ERROR) << "replace node failed."; + return RET_ERROR; } } else { for (auto &node : output_exit_nodes_) { - auto manager = fg_->manager(); auto node_users = manager->node_users()[node]; for (auto &node_user : node_users) { // new getitem @@ -237,7 +234,10 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() { get_item_node->set_fullname_with_scope(output_item_name); // set if (fg_->nodes().contains(node_user.first)) { - manager->SetEdge(node_user.first, node_user.second, get_item_node); + if (!manager->Replace(output_exit_nodes_[0], while_node_)) { + MS_LOG(ERROR) << "replace node failed."; + return RET_ERROR; + } } } } diff --git a/mindspore/lite/tools/optimizer/graph/while_pass.cc b/mindspore/lite/tools/optimizer/graph/while_pass.cc index bd7f84dfd4..93ac70a3bb 100644 --- a/mindspore/lite/tools/optimizer/graph/while_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/while_pass.cc @@ -86,7 +86,7 @@ bool WhilePass::Run(const FuncGraphPtr &graph) { body_to_cond_inputs.emplace_back(body_output_cnode->input(i)); } } else { - body_to_cond_inputs.emplace_back(body_output_cnode->input(1)); + body_to_cond_inputs.emplace_back(body_output_cnode); } // concat body to cond @@ -121,9 +121,9 @@ bool WhilePass::Run(const FuncGraphPtr &graph) { // create cond partial cnode auto manager = graph->manager(); - auto node_users = manager->node_users()[while_cnode]; - for (auto &node_user : node_users) { - manager->SetEdge(node_user.first, node_user.second, switch_cnode); + if (!manager->Replace(while_cnode, switch_cnode)) { + MS_LOG(ERROR) << "replace node failed."; + return false; } } return true;