| @@ -73,54 +73,44 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_ker | |||||
| return used_nodes; | return used_nodes; | ||||
| } | } | ||||
| void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { | |||||
| AnfNodeSet outputs_set; | |||||
| for (auto out : *outputs) { | |||||
| outputs_set.insert(out); | |||||
| } | |||||
| AnfNodePtrList vir_outputs; | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> eqv; | |||||
| auto fg_outputs = fg->output(); | |||||
| if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { | |||||
| auto cnode = fg_outputs->cast<CNodePtr>(); | |||||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||||
| vir_outputs.push_back(cnode->input(i)); | |||||
| void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &users, | |||||
| std::vector<CNodePtr> *control_depend_nodes, std::vector<size_t> *control_depend_use_index, | |||||
| bool *is_only_control_depend_use, AnfNodePtr *use_out) { | |||||
| for (auto &user : users) { | |||||
| auto use_node = user.first; | |||||
| if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { | |||||
| *is_only_control_depend_use = false; | |||||
| continue; | |||||
| } | |||||
| if (outputs_set.count(use_node) != 0) { | |||||
| *use_out = use_node; | |||||
| } | |||||
| if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { | |||||
| control_depend_nodes->push_back(use_node->cast<CNodePtr>()); | |||||
| control_depend_use_index->push_back(user.second); | |||||
| } | } | ||||
| } else { | |||||
| vir_outputs.push_back(fg_outputs); | |||||
| } | } | ||||
| } | |||||
| if (vir_outputs.size() != outputs->size()) { | |||||
| MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; | |||||
| bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv) { | |||||
| AnfNodeSet outputs_set; | |||||
| for (auto out : *outputs) { | |||||
| outputs_set.insert(out); | |||||
| } | } | ||||
| bool has_erase_outs = false; | bool has_erase_outs = false; | ||||
| size_t index = -1; | size_t index = -1; | ||||
| for (auto it = outputs->begin(); it != outputs->end();) { | for (auto it = outputs->begin(); it != outputs->end();) { | ||||
| index++; | index++; | ||||
| auto out = *it; | auto out = *it; | ||||
| eqv[out] = vir_outputs[index]; | |||||
| (*eqv)[out] = vir_outputs[index]; | |||||
| auto users = mng->node_users()[out]; | auto users = mng->node_users()[out]; | ||||
| bool is_only_control_depend_use = true; | bool is_only_control_depend_use = true; | ||||
| std::vector<size_t> control_depend_use_index; | std::vector<size_t> control_depend_use_index; | ||||
| std::vector<CNodePtr> control_depend_nodes; | std::vector<CNodePtr> control_depend_nodes; | ||||
| AnfNodePtr use_out = nullptr; | AnfNodePtr use_out = nullptr; | ||||
| for (auto &user : users) { | |||||
| auto use_node = user.first; | |||||
| if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { | |||||
| is_only_control_depend_use = false; | |||||
| continue; | |||||
| } | |||||
| if (outputs_set.count(use_node) != 0) { | |||||
| use_out = use_node; | |||||
| } | |||||
| if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { | |||||
| control_depend_nodes.push_back(use_node->cast<CNodePtr>()); | |||||
| control_depend_use_index.push_back(user.second); | |||||
| } | |||||
| } | |||||
| SearchForDependNode(outputs_set, users, &control_depend_nodes, &control_depend_use_index, | |||||
| &is_only_control_depend_use, &use_out); | |||||
| if (is_only_control_depend_use && !control_depend_nodes.empty()) { | if (is_only_control_depend_use && !control_depend_nodes.empty()) { | ||||
| MS_EXCEPTION_IF_NULL(use_out); | MS_EXCEPTION_IF_NULL(use_out); | ||||
| it = outputs->erase(it); | it = outputs->erase(it); | ||||
| @@ -142,8 +132,27 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con | |||||
| it++; | it++; | ||||
| } | } | ||||
| } | } | ||||
| return has_erase_outs; | |||||
| } | |||||
| void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { | |||||
| AnfNodePtrList vir_outputs; | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> eqv; | |||||
| auto fg_outputs = fg->output(); | |||||
| if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { | |||||
| auto cnode = fg_outputs->cast<CNodePtr>(); | |||||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||||
| vir_outputs.push_back(cnode->input(i)); | |||||
| } | |||||
| } else { | |||||
| vir_outputs.push_back(fg_outputs); | |||||
| } | |||||
| if (vir_outputs.size() != outputs->size()) { | |||||
| MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; | |||||
| } | |||||
| if (!has_erase_outs) { | |||||
| if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv)) { | |||||
| return; | return; | ||||
| } | } | ||||