From: @mengyuanli Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tongtags/v1.2.0-rc1
| @@ -48,6 +48,9 @@ std::string FunctionalizeControlOpPass::NodeClusterName(const AnfNodePtr &node) | |||||
| void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) { | void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) { | ||||
| for (auto &node : func_graph->nodes()) { | for (auto &node : func_graph->nodes()) { | ||||
| if (!utils::isa<CNodePtr>(node)) { | |||||
| continue; | |||||
| } | |||||
| auto cluster_name = NodeClusterName(node); | auto cluster_name = NodeClusterName(node); | ||||
| auto cluster_pos = WhichCluster(cluster_name); | auto cluster_pos = WhichCluster(cluster_name); | ||||
| if (cluster_pos == node_clusters_.size()) { | if (cluster_pos == node_clusters_.size()) { | ||||
| @@ -90,6 +93,7 @@ STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_g | |||||
| MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret; | MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -196,17 +196,14 @@ STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() { | |||||
| } | } | ||||
| STATUS FunctionalizeWhile::UpdateExitNodeUser() { | STATUS FunctionalizeWhile::UpdateExitNodeUser() { | ||||
| auto manager = fg_->manager(); | |||||
| if (output_exit_nodes_.size() == 1) { | if (output_exit_nodes_.size() == 1) { | ||||
| auto manager = fg_->manager(); | |||||
| auto node_users = manager->node_users()[output_exit_nodes_[0]]; | |||||
| for (auto &node_user : node_users) { | |||||
| if (fg_->nodes().contains(node_user.first)) { | |||||
| manager->SetEdge(node_user.first, node_user.second, while_node_); | |||||
| } | |||||
| if (!manager->Replace(output_exit_nodes_[0], while_node_)) { | |||||
| MS_LOG(ERROR) << "replace node failed."; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| } else { | } else { | ||||
| for (auto &node : output_exit_nodes_) { | for (auto &node : output_exit_nodes_) { | ||||
| auto manager = fg_->manager(); | |||||
| auto node_users = manager->node_users()[node]; | auto node_users = manager->node_users()[node]; | ||||
| for (auto &node_user : node_users) { | for (auto &node_user : node_users) { | ||||
| // new getitem | // new getitem | ||||
| @@ -237,7 +234,10 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() { | |||||
| get_item_node->set_fullname_with_scope(output_item_name); | get_item_node->set_fullname_with_scope(output_item_name); | ||||
| // set | // set | ||||
| if (fg_->nodes().contains(node_user.first)) { | if (fg_->nodes().contains(node_user.first)) { | ||||
| manager->SetEdge(node_user.first, node_user.second, get_item_node); | |||||
| if (!manager->Replace(output_exit_nodes_[0], while_node_)) { | |||||
| MS_LOG(ERROR) << "replace node failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -86,7 +86,7 @@ bool WhilePass::Run(const FuncGraphPtr &graph) { | |||||
| body_to_cond_inputs.emplace_back(body_output_cnode->input(i)); | body_to_cond_inputs.emplace_back(body_output_cnode->input(i)); | ||||
| } | } | ||||
| } else { | } else { | ||||
| body_to_cond_inputs.emplace_back(body_output_cnode->input(1)); | |||||
| body_to_cond_inputs.emplace_back(body_output_cnode); | |||||
| } | } | ||||
| // concat body to cond | // concat body to cond | ||||
| @@ -121,9 +121,9 @@ bool WhilePass::Run(const FuncGraphPtr &graph) { | |||||
| // create cond partial cnode | // create cond partial cnode | ||||
| auto manager = graph->manager(); | auto manager = graph->manager(); | ||||
| auto node_users = manager->node_users()[while_cnode]; | |||||
| for (auto &node_user : node_users) { | |||||
| manager->SetEdge(node_user.first, node_user.second, switch_cnode); | |||||
| if (!manager->Replace(while_cnode, switch_cnode)) { | |||||
| MS_LOG(ERROR) << "replace node failed."; | |||||
| return false; | |||||
| } | } | ||||
| } | } | ||||
| return true; | return true; | ||||