Browse Source

!12691 [MS][LITE]fix bug of functionalize while

From: @mengyuanli
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
41ec878ef5
3 changed files with 16 additions and 12 deletions
  1. +4
    -0
      mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc
  2. +8
    -8
      mindspore/lite/tools/optimizer/graph/functionalize_while.cc
  3. +4
    -4
      mindspore/lite/tools/optimizer/graph/while_pass.cc

+ 4
- 0
mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc View File

@@ -48,6 +48,9 @@ std::string FunctionalizeControlOpPass::NodeClusterName(const AnfNodePtr &node)


void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) { void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) {
for (auto &node : func_graph->nodes()) { for (auto &node : func_graph->nodes()) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cluster_name = NodeClusterName(node); auto cluster_name = NodeClusterName(node);
auto cluster_pos = WhichCluster(cluster_name); auto cluster_pos = WhichCluster(cluster_name);
if (cluster_pos == node_clusters_.size()) { if (cluster_pos == node_clusters_.size()) {
@@ -90,6 +93,7 @@ STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_g
MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret; MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret;
return ret; return ret;
} }
break;
} }
} }
} }


+ 8
- 8
mindspore/lite/tools/optimizer/graph/functionalize_while.cc View File

@@ -196,17 +196,14 @@ STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() {
} }


STATUS FunctionalizeWhile::UpdateExitNodeUser() { STATUS FunctionalizeWhile::UpdateExitNodeUser() {
auto manager = fg_->manager();
if (output_exit_nodes_.size() == 1) { if (output_exit_nodes_.size() == 1) {
auto manager = fg_->manager();
auto node_users = manager->node_users()[output_exit_nodes_[0]];
for (auto &node_user : node_users) {
if (fg_->nodes().contains(node_user.first)) {
manager->SetEdge(node_user.first, node_user.second, while_node_);
}
if (!manager->Replace(output_exit_nodes_[0], while_node_)) {
MS_LOG(ERROR) << "replace node failed.";
return RET_ERROR;
} }
} else { } else {
for (auto &node : output_exit_nodes_) { for (auto &node : output_exit_nodes_) {
auto manager = fg_->manager();
auto node_users = manager->node_users()[node]; auto node_users = manager->node_users()[node];
for (auto &node_user : node_users) { for (auto &node_user : node_users) {
// new getitem // new getitem
@@ -237,7 +234,10 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
get_item_node->set_fullname_with_scope(output_item_name); get_item_node->set_fullname_with_scope(output_item_name);
// set // set
if (fg_->nodes().contains(node_user.first)) { if (fg_->nodes().contains(node_user.first)) {
manager->SetEdge(node_user.first, node_user.second, get_item_node);
if (!manager->Replace(output_exit_nodes_[0], while_node_)) {
MS_LOG(ERROR) << "replace node failed.";
return RET_ERROR;
}
} }
} }
} }


+ 4
- 4
mindspore/lite/tools/optimizer/graph/while_pass.cc View File

@@ -86,7 +86,7 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
body_to_cond_inputs.emplace_back(body_output_cnode->input(i)); body_to_cond_inputs.emplace_back(body_output_cnode->input(i));
} }
} else { } else {
body_to_cond_inputs.emplace_back(body_output_cnode->input(1));
body_to_cond_inputs.emplace_back(body_output_cnode);
} }


// concat body to cond // concat body to cond
@@ -121,9 +121,9 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {


// create cond partial cnode // create cond partial cnode
auto manager = graph->manager(); auto manager = graph->manager();
auto node_users = manager->node_users()[while_cnode];
for (auto &node_user : node_users) {
manager->SetEdge(node_user.first, node_user.second, switch_cnode);
if (!manager->Replace(while_cnode, switch_cnode)) {
MS_LOG(ERROR) << "replace node failed.";
return false;
} }
} }
return true; return true;


Loading…
Cancel
Save