diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc
index f4d91f5858..07a2e9b7d0 100644
--- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc
+++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc
@@ -261,13 +261,14 @@ void AscendControlParser::EraseParameter(NotNull root_graph,
     }
   }
-  EraseAssign(std::make_shared(parameter_count), all_nodes, para_to_written_node, root_graph);
+  EraseAssign(std::make_shared(parameter_count), all_nodes, para_to_written_node, root_graph,
+              graph_list);
 }
 
 void AscendControlParser::EraseAssign(std::shared_ptr parameter_count, const std::set &all_nodes,
                                       const std::map &para_to_written_node,
-                                      NotNull root_graph) {
+                                      NotNull root_graph, const std::set &graph_list) {
   std::vector exec_order = root_graph->execution_order();
   while (parameter_count->HasValidElem()) {
     auto [para, read, written] = parameter_count->GetOneValidElem();
@@ -292,6 +293,8 @@ void AscendControlParser::EraseAssign(std::shared_ptr paramete
     if (visit_source->isa()) {
       parameter_count->AddReadCount(visit_source, read - 1);
     }
+
+    // replace parameter in node
     for (auto &node : all_nodes) {
       for (size_t i = 0; i < node->size(); ++i) {
         if (node->input(i) == para) {
@@ -300,6 +303,14 @@ void AscendControlParser::EraseAssign(std::shared_ptr paramete
         }
       }
     }
+
+    // replace parameter in graph input
+    for (auto &g : graph_list) {
+      auto child_graph_inputs = g->MutableInputs();
+      std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), para, source);
+      MS_LOG_INFO << "Replace parameter " << para->DebugString() << " by " << source->DebugString() << " in graph "
+                  << g->graph_id() << " inputs";
+    }
   }
   root_graph->set_execution_order(exec_order);
 }
diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h
index 466b997db3..ddecd0a81a 100644
--- a/mindspore/ccsrc/backend/session/ascend_control_parser.h
+++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h
@@ -47,7 +47,7 @@ class AscendControlParser {
   static void EraseParameter(NotNull root_graph, const std::set &graph_list);
   static void EraseAssign(std::shared_ptr parameter_count, const std::set &all_nodes,
                           const std::map &para_to_written_node,
-                          NotNull root_graph);
+                          NotNull root_graph, const std::set &graph_list);
   static void EraseLabel(NotNull root_graph);
   static void ChildGraphDataAssign(NotNull kg,
                                    const NotNull> *> link_list,
diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index a5baddd517..b5069c4acb 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -153,9 +153,6 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) {
   HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
   memo.clear();
-  AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
-  memo.clear();
-
   UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
   memo.clear();
   // add make_tuple to the output graph
@@ -178,7 +175,10 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) {
     debugger_->PreExecute(root_graph);
   }
   SetSummaryNodes(root_graph.get());
-  // alloc mem
+  // Alloc memory for child graph's inputs
+  AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
+  memo.clear();
+  // Alloc memory for root graph's inputs and node's outputs, workspace
   MemoryAlloc(root_graph.get());
   // generate and load task into device
   Load(root_graph);
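Editor's note (sketch, not part of the patch): as I read these hunks, EraseAssign now receives graph_list so that, when a folded Assign lets a parameter be replaced by its source node, the replacement is applied not only in every consuming node but also in each child graph's input list; AssignStaticMemory is correspondingly moved to run just before MemoryAlloc so static input memory is assigned only after that rewrite. The standalone C++ sketch below illustrates the replace-in-nodes plus replace-in-graph-inputs pattern. Node, Graph, NodePtr and ReplaceParameter are hypothetical stand-ins, not MindSpore's AnfNodePtr/KernelGraph API.

// Standalone sketch: rewrite a parameter to its source in node inputs and graph inputs.
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;  // stand-in for a CNode's input list
};
using NodePtr = std::shared_ptr<Node>;

struct Graph {
  int graph_id;
  std::vector<NodePtr> inputs;  // stand-in for a child graph's MutableInputs()
};

// Replace `para` by `source` everywhere it is consumed, mirroring the two loops
// added to EraseAssign: once over all nodes, once over every child graph's inputs.
void ReplaceParameter(const NodePtr &para, const NodePtr &source,
                      const std::vector<NodePtr> &all_nodes, std::vector<Graph> *graphs) {
  for (const auto &node : all_nodes) {
    std::replace(node->inputs.begin(), node->inputs.end(), para, source);
  }
  for (auto &g : *graphs) {
    std::replace(g.inputs.begin(), g.inputs.end(), para, source);
    std::cout << "Replace parameter " << para->name << " by " << source->name
              << " in graph " << g.graph_id << " inputs" << std::endl;
  }
}

int main() {
  auto para = std::make_shared<Node>(Node{"para", {}});
  auto source = std::make_shared<Node>(Node{"assign_source", {}});
  auto consumer = std::make_shared<Node>(Node{"consumer", {para}});
  std::vector<Graph> graphs = {{1, {para}}, {2, {para, source}}};
  ReplaceParameter(para, source, {consumer}, &graphs);
  std::cout << "consumer input is now: " << consumer->inputs[0]->name << std::endl;
  return 0;
}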
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
index 8b3c8f0988..994c8a82fb 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -337,6 +337,8 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
       if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
         MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
       }
+      MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
+                   << " index: " << index << " size: " << tensor_size;
       AnfAlgo::SetOutputAddr(address, index, item.get());
     }
   }
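Editor's note (sketch, not part of the patch): the added MS_LOG(INFO) records which graph, node and output index received static input memory and how large the allocation was, which helps correlate the relocated AssignStaticMemory pass with the addresses later used by MemoryAlloc. The sketch below shows that allocate-check-log pattern in isolation; StaticMemManager, Input and MallocMem are placeholders, not the real KernelRuntime or MemoryManager API.

// Standalone sketch: allocate static memory per graph input and log each assignment.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

struct Input {
  std::string name;
  size_t tensor_size;
};

class StaticMemManager {
 public:
  // Returns an offset into a simulated static region; throws if the region is exhausted,
  // mirroring the EXCEPTION path when MallocMem returns nullptr.
  size_t MallocMem(size_t size) {
    if (offset_ + size > capacity_) {
      throw std::runtime_error("Cannot alloc address, tensor size is: " + std::to_string(size));
    }
    size_t addr = offset_;
    offset_ += size;
    return addr;
  }

 private:
  size_t offset_ = 0;
  size_t capacity_ = 1 << 20;  // 1 MiB static region for the sketch
};

int main() {
  StaticMemManager mem_manager;
  const uint32_t graph_id = 3;
  std::vector<Input> inputs = {{"Parameter_a", 4096}, {"Parameter_b", 16384}};

  for (size_t index = 0; index < inputs.size(); ++index) {
    const auto &item = inputs[index];
    size_t address = mem_manager.MallocMem(item.tensor_size);
    // Mirrors the added MS_LOG(INFO): trace graph id, node, output index and size
    // once the static input allocation has succeeded.
    std::cout << "Malloc Input for graph " << graph_id << ", node: " << item.name
              << " index: " << index << " size: " << item.tensor_size
              << " address offset: " << address << std::endl;
  }
  return 0;
}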