|
|
|
@@ -73,54 +73,44 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_ker |
|
|
|
return used_nodes; |
|
|
|
} |
|
|
|
|
|
|
|
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { |
|
|
|
AnfNodeSet outputs_set; |
|
|
|
for (auto out : *outputs) { |
|
|
|
outputs_set.insert(out); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtrList vir_outputs; |
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv; |
|
|
|
auto fg_outputs = fg->output(); |
|
|
|
if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { |
|
|
|
auto cnode = fg_outputs->cast<CNodePtr>(); |
|
|
|
for (size_t i = 1; i < cnode->size(); ++i) { |
|
|
|
vir_outputs.push_back(cnode->input(i)); |
|
|
|
void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &users, |
|
|
|
std::vector<CNodePtr> *control_depend_nodes, std::vector<size_t> *control_depend_use_index, |
|
|
|
bool *is_only_control_depend_use, AnfNodePtr *use_out) { |
|
|
|
for (auto &user : users) { |
|
|
|
auto use_node = user.first; |
|
|
|
if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { |
|
|
|
*is_only_control_depend_use = false; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (outputs_set.count(use_node) != 0) { |
|
|
|
*use_out = use_node; |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { |
|
|
|
control_depend_nodes->push_back(use_node->cast<CNodePtr>()); |
|
|
|
control_depend_use_index->push_back(user.second); |
|
|
|
} |
|
|
|
} else { |
|
|
|
vir_outputs.push_back(fg_outputs); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (vir_outputs.size() != outputs->size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; |
|
|
|
bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng, |
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv) { |
|
|
|
AnfNodeSet outputs_set; |
|
|
|
for (auto out : *outputs) { |
|
|
|
outputs_set.insert(out); |
|
|
|
} |
|
|
|
bool has_erase_outs = false; |
|
|
|
size_t index = -1; |
|
|
|
for (auto it = outputs->begin(); it != outputs->end();) { |
|
|
|
index++; |
|
|
|
auto out = *it; |
|
|
|
eqv[out] = vir_outputs[index]; |
|
|
|
(*eqv)[out] = vir_outputs[index]; |
|
|
|
auto users = mng->node_users()[out]; |
|
|
|
bool is_only_control_depend_use = true; |
|
|
|
std::vector<size_t> control_depend_use_index; |
|
|
|
std::vector<CNodePtr> control_depend_nodes; |
|
|
|
AnfNodePtr use_out = nullptr; |
|
|
|
for (auto &user : users) { |
|
|
|
auto use_node = user.first; |
|
|
|
if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { |
|
|
|
is_only_control_depend_use = false; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (outputs_set.count(use_node) != 0) { |
|
|
|
use_out = use_node; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { |
|
|
|
control_depend_nodes.push_back(use_node->cast<CNodePtr>()); |
|
|
|
control_depend_use_index.push_back(user.second); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
SearchForDependNode(outputs_set, users, &control_depend_nodes, &control_depend_use_index, |
|
|
|
&is_only_control_depend_use, &use_out); |
|
|
|
if (is_only_control_depend_use && !control_depend_nodes.empty()) { |
|
|
|
MS_EXCEPTION_IF_NULL(use_out); |
|
|
|
it = outputs->erase(it); |
|
|
|
@@ -142,8 +132,27 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con |
|
|
|
it++; |
|
|
|
} |
|
|
|
} |
|
|
|
return has_erase_outs; |
|
|
|
} |
|
|
|
|
|
|
|
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { |
|
|
|
AnfNodePtrList vir_outputs; |
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv; |
|
|
|
auto fg_outputs = fg->output(); |
|
|
|
if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { |
|
|
|
auto cnode = fg_outputs->cast<CNodePtr>(); |
|
|
|
for (size_t i = 1; i < cnode->size(); ++i) { |
|
|
|
vir_outputs.push_back(cnode->input(i)); |
|
|
|
} |
|
|
|
} else { |
|
|
|
vir_outputs.push_back(fg_outputs); |
|
|
|
} |
|
|
|
|
|
|
|
if (vir_outputs.size() != outputs->size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; |
|
|
|
} |
|
|
|
|
|
|
|
if (!has_erase_outs) { |
|
|
|
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
|