|
|
|
@@ -187,6 +187,31 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f |
|
|
|
offset += real_outs.size() - 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// remove parameter which is not used |
|
|
|
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(inputs); |
|
|
|
const auto &ori_parameter = func_graph->parameters(); |
|
|
|
auto todos = TopoSort(func_graph->get_return()); |
|
|
|
std::set<AnfNodePtr> used_param; |
|
|
|
for (auto node : todos) { |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
(void)used_param.insert(node); |
|
|
|
} |
|
|
|
} |
|
|
|
if (used_param.size() == ori_parameter.size()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
AnfNodePtrList new_parameter, new_inputs{(*inputs)[0]}; |
|
|
|
for (size_t i = 0; i < ori_parameter.size(); ++i) { |
|
|
|
if (used_param.count(ori_parameter[i])) { |
|
|
|
new_parameter.push_back(ori_parameter[i]); |
|
|
|
new_inputs.push_back((*inputs)[i + 1]); |
|
|
|
} |
|
|
|
} |
|
|
|
func_graph->set_parameters(new_parameter); |
|
|
|
*inputs = std::move(new_inputs); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &nodes) { |
|
|
|
@@ -250,6 +275,7 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNod |
|
|
|
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs) { |
|
|
|
std::vector<AnfNodePtr> fn_inputs{NewValueNode(sub_fg)}; |
|
|
|
fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end()); |
|
|
|
EliminateRedundantParameters(sub_fg, &fn_inputs); |
|
|
|
auto fuse_cnode = main_fg->NewCNode(fn_inputs); |
|
|
|
fuse_cnode->set_abstract(sub_fg->output()->abstract()); |
|
|
|
Callback::Instance()->SetGraphKernelNodeKernelInfo(fuse_cnode); |
|
|
|
|