Merge pull request !4030 from laiyongqiang/replace_parametertags/v1.1.0
| @@ -261,13 +261,14 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph, | |||||
| } | } | ||||
| } | } | ||||
| EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph); | |||||
| EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph, | |||||
| graph_list); | |||||
| } | } | ||||
| void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, | void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, | ||||
| const std::set<CNodePtr> &all_nodes, | const std::set<CNodePtr> &all_nodes, | ||||
| const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node, | const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node, | ||||
| NotNull<KernelGraphPtr> root_graph) { | |||||
| NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list) { | |||||
| std::vector<CNodePtr> exec_order = root_graph->execution_order(); | std::vector<CNodePtr> exec_order = root_graph->execution_order(); | ||||
| while (parameter_count->HasValidElem()) { | while (parameter_count->HasValidElem()) { | ||||
| auto [para, read, written] = parameter_count->GetOneValidElem(); | auto [para, read, written] = parameter_count->GetOneValidElem(); | ||||
| @@ -292,6 +293,8 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete | |||||
| if (visit_source->isa<Parameter>()) { | if (visit_source->isa<Parameter>()) { | ||||
| parameter_count->AddReadCount(visit_source, read - 1); | parameter_count->AddReadCount(visit_source, read - 1); | ||||
| } | } | ||||
| // replace parameter in node | |||||
| for (auto &node : all_nodes) { | for (auto &node : all_nodes) { | ||||
| for (size_t i = 0; i < node->size(); ++i) { | for (size_t i = 0; i < node->size(); ++i) { | ||||
| if (node->input(i) == para) { | if (node->input(i) == para) { | ||||
| @@ -300,6 +303,14 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // replace parameter in graph input | |||||
| for (auto &g : graph_list) { | |||||
| auto child_graph_inputs = g->MutableInputs(); | |||||
| std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), para, source); | |||||
| MS_LOG_INFO << "Replace parameter " << para->DebugString() << " by " << source->DebugString() << " in graph " | |||||
| << g->graph_id() << " inputs"; | |||||
| } | |||||
| } | } | ||||
| root_graph->set_execution_order(exec_order); | root_graph->set_execution_order(exec_order); | ||||
| } | } | ||||
| @@ -47,7 +47,7 @@ class AscendControlParser { | |||||
| static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list); | static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list); | ||||
| static void EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, const std::set<CNodePtr> &all_nodes, | static void EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, const std::set<CNodePtr> &all_nodes, | ||||
| const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node, | const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node, | ||||
| NotNull<KernelGraphPtr> root_graph); | |||||
| NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list); | |||||
| static void EraseLabel(NotNull<KernelGraphPtr> root_graph); | static void EraseLabel(NotNull<KernelGraphPtr> root_graph); | ||||
| static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg, | static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg, | ||||
| const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list, | const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list, | ||||
| @@ -153,9 +153,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo)); | HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo)); | ||||
| memo.clear(); | memo.clear(); | ||||
| AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo)); | |||||
| memo.clear(); | |||||
| UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo)); | UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo)); | ||||
| memo.clear(); | memo.clear(); | ||||
| // add make_tuple to the output graph | // add make_tuple to the output graph | ||||
| @@ -178,7 +175,10 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| debugger_->PreExecute(root_graph); | debugger_->PreExecute(root_graph); | ||||
| } | } | ||||
| SetSummaryNodes(root_graph.get()); | SetSummaryNodes(root_graph.get()); | ||||
| // alloc mem | |||||
| // Alloc memory for child graph's inputs | |||||
| AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo)); | |||||
| memo.clear(); | |||||
| // Alloc memory for root graph's inputs and node's outputs, workspace | |||||
| MemoryAlloc(root_graph.get()); | MemoryAlloc(root_graph.get()); | ||||
| // generate and load task into device | // generate and load task into device | ||||
| Load(root_graph); | Load(root_graph); | ||||
| @@ -337,6 +337,8 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||||
| if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { | if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; | MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope() | |||||
| << " index: " << index << " size: " << tensor_size; | |||||
| AnfAlgo::SetOutputAddr(address, index, item.get()); | AnfAlgo::SetOutputAddr(address, index, item.get()); | ||||
| } | } | ||||
| } | } | ||||