|
|
|
@@ -261,17 +261,16 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
EraseAssign(all_nodes, para_to_written_node, root_graph); |
|
|
|
root_graph->set_execution_order(exec_order); |
|
|
|
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph); |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::EraseAssign(const std::set<CNodePtr> &all_nodes, |
|
|
|
void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, |
|
|
|
const std::set<CNodePtr> &all_nodes, |
|
|
|
const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node, |
|
|
|
NotNull<KernelGraphPtr> root_graph) { |
|
|
|
std::vector<CNodePtr> exec_order = root_graph->execution_order(); |
|
|
|
ReferenceCounter parameter_count([](int32_t read, int32_t write) -> bool { return write == 1; }); |
|
|
|
while (parameter_count.HasValidElem()) { |
|
|
|
auto [para, read, written] = parameter_count.GetOneValidElem(); |
|
|
|
while (parameter_count->HasValidElem()) { |
|
|
|
auto [para, read, written] = parameter_count->GetOneValidElem(); |
|
|
|
MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times."; |
|
|
|
auto assign_iter = para_to_written_node.find(para); |
|
|
|
if (assign_iter == para_to_written_node.end()) { |
|
|
|
@@ -280,7 +279,7 @@ void AscendControlParser::EraseAssign(const std::set<CNodePtr> &all_nodes, |
|
|
|
auto &assign_node = assign_iter->second; |
|
|
|
MS_EXCEPTION_IF_NULL(assign_node); |
|
|
|
if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign)) { |
|
|
|
parameter_count.EraseElem(para); |
|
|
|
parameter_count->EraseElem(para); |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); |
|
|
|
@@ -288,10 +287,10 @@ void AscendControlParser::EraseAssign(const std::set<CNodePtr> &all_nodes, |
|
|
|
auto source = assign_node->input(kCNodeAssignSource); |
|
|
|
MS_EXCEPTION_IF_NULL(source); |
|
|
|
auto visit_source = AnfAlgo::VisitKernelWithReturnType(source, 0).first; |
|
|
|
parameter_count.AddWriteCount(para, -1); |
|
|
|
parameter_count.AddReadCount(para, -1); |
|
|
|
parameter_count->AddWriteCount(para, -1); |
|
|
|
parameter_count->AddReadCount(para, -1); |
|
|
|
if (visit_source->isa<Parameter>()) { |
|
|
|
parameter_count.AddReadCount(visit_source, read - 1); |
|
|
|
parameter_count->AddReadCount(visit_source, read - 1); |
|
|
|
} |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
for (size_t i = 0; i < node->size(); ++i) { |
|
|
|
@@ -302,6 +301,7 @@ void AscendControlParser::EraseAssign(const std::set<CNodePtr> &all_nodes, |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
root_graph->set_execution_order(exec_order); |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::EraseLabel(NotNull<KernelGraphPtr> root_graph) { |
|
|
|
|