|
|
|
@@ -261,13 +261,14 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph); |
|
|
|
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph, |
|
|
|
graph_list); |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, |
|
|
|
const std::set<CNodePtr> &all_nodes, |
|
|
|
const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node, |
|
|
|
NotNull<KernelGraphPtr> root_graph) { |
|
|
|
NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list) { |
|
|
|
std::vector<CNodePtr> exec_order = root_graph->execution_order(); |
|
|
|
while (parameter_count->HasValidElem()) { |
|
|
|
auto [para, read, written] = parameter_count->GetOneValidElem(); |
|
|
|
@@ -292,6 +293,8 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete |
|
|
|
if (visit_source->isa<Parameter>()) { |
|
|
|
parameter_count->AddReadCount(visit_source, read - 1); |
|
|
|
} |
|
|
|
|
|
|
|
// replace parameter in node |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
for (size_t i = 0; i < node->size(); ++i) { |
|
|
|
if (node->input(i) == para) { |
|
|
|
@@ -300,6 +303,14 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// replace parameter in graph input |
|
|
|
for (auto &g : graph_list) { |
|
|
|
auto child_graph_inputs = g->MutableInputs(); |
|
|
|
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), para, source); |
|
|
|
MS_LOG_INFO << "Replace parameter " << para->DebugString() << " by " << source->DebugString() << " in graph " |
|
|
|
<< g->graph_id() << " inputs"; |
|
|
|
} |
|
|
|
} |
|
|
|
root_graph->set_execution_order(exec_order); |
|
|
|
} |
|
|
|
|