From d6aa3a2d5dbdb1086c0d84f3428d37d35275dfaa Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Wed, 24 Feb 2021 17:24:47 +0800 Subject: [PATCH] fix bug of while pass --- .../graph/functionalize_control_op_pass.cc | 4 ++++ .../tools/optimizer/graph/functionalize_while.cc | 16 ++++++++-------- .../lite/tools/optimizer/graph/while_pass.cc | 8 ++++---- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc index fa54db1a0e..90ef51b643 100644 --- a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc @@ -49,6 +49,9 @@ std::string FunctionalizeControlOpPass::NodeClusterName(const AnfNodePtr &node) void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) { for (auto &node : func_graph->nodes()) { + if (!utils::isa(node)) { + continue; + } auto cluster_name = NodeClusterName(node); auto cluster_pos = WhichCluster(cluster_name); if (cluster_pos == node_clusters_.size()) { @@ -91,6 +94,7 @@ STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_g MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret; return ret; } + break; } } } diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_while.cc b/mindspore/lite/tools/optimizer/graph/functionalize_while.cc index 2d1d8b53d4..bbf58e5f92 100644 --- a/mindspore/lite/tools/optimizer/graph/functionalize_while.cc +++ b/mindspore/lite/tools/optimizer/graph/functionalize_while.cc @@ -204,17 +204,14 @@ STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() { } STATUS FunctionalizeWhile::UpdateExitNodeUser() { + auto manager = fg_->manager(); if (output_exit_nodes_.size() == 1) { - auto manager = fg_->manager(); - auto node_users = manager->node_users()[output_exit_nodes_[0]]; - for (auto &node_user : node_users) { - if (fg_->nodes().contains(node_user.first)) { - manager->SetEdge(node_user.first, node_user.second, while_node_); - } + if (!manager->Replace(output_exit_nodes_[0], while_node_)) { + MS_LOG(ERROR) << "replace node failed."; + return RET_ERROR; } } else { for (auto &node : output_exit_nodes_) { - auto manager = fg_->manager(); auto node_users = manager->node_users()[node]; for (auto &node_user : node_users) { // new getitem @@ -245,7 +242,10 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() { get_item_node->set_fullname_with_scope(output_item_name); // set if (fg_->nodes().contains(node_user.first)) { - manager->SetEdge(node_user.first, node_user.second, get_item_node); + if (!manager->Replace(output_exit_nodes_[0], while_node_)) { + MS_LOG(ERROR) << "replace node failed."; + return RET_ERROR; + } } } } diff --git a/mindspore/lite/tools/optimizer/graph/while_pass.cc b/mindspore/lite/tools/optimizer/graph/while_pass.cc index 486f77f7bc..d2088606e4 100644 --- a/mindspore/lite/tools/optimizer/graph/while_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/while_pass.cc @@ -106,7 +106,7 @@ bool WhilePass::Run(const FuncGraphPtr &graph) { body_to_cond_inputs.emplace_back(body_output_cnode->input(i)); } } else { - body_to_cond_inputs.emplace_back(body_output_cnode->input(1)); + body_to_cond_inputs.emplace_back(body_output_cnode); } // concat body to cond @@ -141,9 +141,9 @@ bool WhilePass::Run(const FuncGraphPtr &graph) { // create cond partial cnode auto manager = graph->manager(); - auto node_users = manager->node_users()[while_cnode]; - for (auto &node_user : node_users) { - manager->SetEdge(node_user.first, node_user.second, switch_cnode); + if (!manager->Replace(while_cnode, switch_cnode)) { + MS_LOG(ERROR) << "replace node failed."; + return false; } } return true;