diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc index e93d2b6dc5..df133b9f5f 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc @@ -73,54 +73,44 @@ std::vector FindFuseCNodes(const CNodePtr &cnode, bool is_before_ker return used_nodes; } -void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { - AnfNodeSet outputs_set; - for (auto out : *outputs) { - outputs_set.insert(out); - } - - AnfNodePtrList vir_outputs; - std::unordered_map eqv; - auto fg_outputs = fg->output(); - if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { - auto cnode = fg_outputs->cast(); - for (size_t i = 1; i < cnode->size(); ++i) { - vir_outputs.push_back(cnode->input(i)); +void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &users, + std::vector *control_depend_nodes, std::vector *control_depend_use_index, + bool *is_only_control_depend_use, AnfNodePtr *use_out) { + for (auto &user : users) { + auto use_node = user.first; + if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { + *is_only_control_depend_use = false; + continue; + } + if (outputs_set.count(use_node) != 0) { + *use_out = use_node; + } + if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { + control_depend_nodes->push_back(use_node->cast()); + control_depend_use_index->push_back(user.second); } - } else { - vir_outputs.push_back(fg_outputs); } +} - if (vir_outputs.size() != outputs->size()) { - MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; +bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng, + std::unordered_map *eqv) { + AnfNodeSet outputs_set; + for (auto out : *outputs) { + outputs_set.insert(out); } bool has_erase_outs = false; size_t index = -1; for (auto it = outputs->begin(); it != outputs->end();) { index++; auto out = *it; - eqv[out] = vir_outputs[index]; + (*eqv)[out] = vir_outputs[index]; auto users = mng->node_users()[out]; bool is_only_control_depend_use = true; std::vector control_depend_use_index; std::vector control_depend_nodes; AnfNodePtr use_out = nullptr; - for (auto &user : users) { - auto use_node = user.first; - if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { - is_only_control_depend_use = false; - continue; - } - if (outputs_set.count(use_node) != 0) { - use_out = use_node; - } - - if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { - control_depend_nodes.push_back(use_node->cast()); - control_depend_use_index.push_back(user.second); - } - } - + SearchForDependNode(outputs_set, users, &control_depend_nodes, &control_depend_use_index, + &is_only_control_depend_use, &use_out); if (is_only_control_depend_use && !control_depend_nodes.empty()) { MS_EXCEPTION_IF_NULL(use_out); it = outputs->erase(it); @@ -142,8 +132,27 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con it++; } } + return has_erase_outs; +} + +void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { + AnfNodePtrList vir_outputs; + std::unordered_map eqv; + auto fg_outputs = fg->output(); + if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { + auto cnode = fg_outputs->cast(); + for (size_t i = 1; i < cnode->size(); ++i) { + vir_outputs.push_back(cnode->input(i)); + } + } else { + vir_outputs.push_back(fg_outputs); + } + + if (vir_outputs.size() != outputs->size()) { + MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; + } - if (!has_erase_outs) { + if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv)) { return; }