From b6035cf1a135d130a3d2cc229ffe449bbef4294a Mon Sep 17 00:00:00 2001
From: tanghuikang
Date: Wed, 27 Oct 2021 16:20:11 +0800
Subject: [PATCH] Ascend swap support communication op

---
 .../ccsrc/backend/session/ascend_session.cc   |   8 +-
 mindspore/ccsrc/backend/session/executor.cc   |   2 +-
 .../ccsrc/backend/session/session_basic.cc    |   4 +-
 .../ascend/distribute/ascend_collective.cc    |   2 +-
 .../ccsrc/runtime/device/kernel_runtime.cc    | 197 +++++++++---------
 .../ccsrc/runtime/device/kernel_runtime.h     |  17 +-
 6 files changed, 118 insertions(+), 112 deletions(-)

diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index e1aad42da4..8b973cfa77 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -294,7 +294,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
   MS_EXCEPTION_IF_NULL(kernel_graph);
   device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(kernel_graph);
   auto &input_nodes = kernel_graph->input_nodes();
-  if (device::KernelRuntime::use_mem_scheduler()) {
+  if (device::KernelRuntime::UseMemScheduler()) {
     kernel_graph->SetInputTensors(inputs);
     return;
   }
@@ -539,7 +539,7 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
   } else {
     // alloc memory, including static memory and dynamic memory
     MemoryAlloc(graph.get());
-    if (!device::KernelRuntime::use_mem_scheduler()) {
+    if (!device::KernelRuntime::UseMemScheduler()) {
       AnfAlgo::CacheAddrForGraph(graph);
     }
     // generate and load task info to device if it is sink mode
@@ -576,7 +576,7 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
   // optimize graph
   HardwareOptimize(child_graph);
   // assign static memory of parameters
-  if (!device::KernelRuntime::use_mem_scheduler()) {
+  if (!device::KernelRuntime::UseMemScheduler()) {
     auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
     MS_EXCEPTION_IF_NULL(runtime_instance);
     runtime_instance->AssignStaticMemoryInput(*child_graph);
@@ -1800,7 +1800,7 @@ void AscendSession::ExecuteAllTaskInQueue() {
 void AscendSession::UpdateOutputTensors(const VectorRef *outputs,
                                         const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                                         std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
-  if (device::KernelRuntime::use_mem_scheduler()) {
+  if (device::KernelRuntime::UseMemScheduler()) {
     return;
   }
   MS_EXCEPTION_IF_NULL(outputs);
diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc
index cf0b30afaa..63addb29d9 100644
--- a/mindspore/ccsrc/backend/session/executor.cc
+++ b/mindspore/ccsrc/backend/session/executor.cc
@@ -130,7 +130,7 @@ void RunGraphTask::Run() {
     return;
   }
   graph->ResetGraphRunningStatus();
-  if (device::KernelRuntime::use_mem_scheduler()) {
+  if (device::KernelRuntime::UseMemScheduler()) {
     graph->SetOutputNodeToTensor(node_to_tensor_);
   }
   try {
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 3642441d0c..4086ae4bf0 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -1691,8 +1691,6 @@ void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vecto
     MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
     outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, node_to_tensor));
   }
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
 }

 void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
@@ -1700,7 +1698,7 @@ void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
                                        std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  if (device::KernelRuntime::use_mem_scheduler()) {
+  if (device::KernelRuntime::UseMemScheduler()) {
     return;
   }
   MS_EXCEPTION_IF_NULL(outputs);
diff --git a/mindspore/ccsrc/runtime/device/ascend/distribute/ascend_collective.cc b/mindspore/ccsrc/runtime/device/ascend/distribute/ascend_collective.cc
index 0270f52725..ed7b3a9321 100644
--- a/mindspore/ccsrc/runtime/device/ascend/distribute/ascend_collective.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/distribute/ascend_collective.cc
@@ -49,7 +49,7 @@ bool HcclCollectiveGroup::InitCollective() {
       << "Loading libascend_collective.so failed. Many reasons could cause this:\n1.libascend_collective.so is not "
          "installed.\n2.hccl is not "
          "installed or found.\n3.mpi is not installed or found, please check if lib files of OpenMPI is added to "
-         "LD_LIBRATY_PATH.";
+         "LD_LIBRARY_PATH.";
   }
   init_mpi_ = DlsymFuncObj(InitMPI, collective_handle_);
   finalize_mpi_ = DlsymFuncObj(FinalizeMPI, collective_handle_);
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
index b7abd1841d..ca426e1b58 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -98,9 +98,10 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_
 void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  if (use_mem_scheduler()) {
+  if (UseMemScheduler()) {
     AssignStaticMemoryValueNode(graph);
     ResetNodeAddress(graph);
+    AssignCommunicationMem(graph);
   } else {
     MS_EXCEPTION_IF_NULL(mem_manager_);
     mem_manager_->ResetDynamicMemory();
@@ -110,9 +111,9 @@ void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
   UpdateRefNodeOutputMem(graph);
 }

-void KernelRuntime::RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
-                                                   std::vector<DeviceAddressPtr> *address_list,
-                                                   std::vector<size_t> *align_size_list) const {
+void KernelRuntime::GetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
+                                              DeviceAddressPtrList *address_list,
+                                              std::vector<size_t> *align_size_list) const {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(total_size);
   MS_EXCEPTION_IF_NULL(address_list);
@@ -140,24 +141,19 @@ void KernelRuntime::RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_
   }
 }

-void KernelRuntime::RunOpAssignCommunicationInput(const AnfNodePtr &node) const {
+void KernelRuntime::AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const {
   if (!AnfAlgo::IsCommunicationOp(node)) {
     return;
   }
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
+
   size_t total_size = 0;
-  std::vector<DeviceAddressPtr> address_list;
+  DeviceAddressPtrList address_list;
   std::vector<size_t> align_size_list;
-  RunOpGetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list);
-  if (address_list.empty()) {
-    return;
-  }
-
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->inputs().size() < kMinInputSize) {
-    MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
+  GetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list);
+  if (align_size_list.empty()) {
+    MS_LOG(WARNING) << "No inputs for " << node->fullname_with_scope();
     return;
   }
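Note (illustration, not part of the patch): GetCommunicationInputInfo collects one aligned slice size per input of a communication op, and MallocContinuousMemFromMemPool then binds every DeviceAddress to an offset inside a single contiguous block, which is what a fused collective such as an HCCL AllReduce needs. DeviceAddressPtrList is MindSpore's alias for std::vector<DeviceAddressPtr>. The standalone C++ sketch below only illustrates that layout; the 512-byte granularity and the round-up rule are assumptions of the sketch, not values taken from this patch.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

namespace {
constexpr size_t kAlign = 512;  // assumed alignment granularity of the device pool

// Round one tensor size up to the assumed granularity (stand-in for
// MemoryManager::GetCommonAlignSize).
size_t AlignSize(size_t size) { return (size + kAlign - 1) / kAlign * kAlign; }
}  // namespace

int main() {
  // Raw byte sizes of the inputs of one communication op.
  const std::vector<size_t> input_sizes = {1000, 4096, 12};

  // Pass 1: aligned slice sizes and the total block size, as
  // GetCommunicationInputInfo accumulates them into align_size_list / total_size.
  std::vector<size_t> align_size_list;
  size_t total_size = 0;
  for (size_t s : input_sizes) {
    align_size_list.push_back(AlignSize(s));
    total_size += align_size_list.back();
  }

  // Pass 2: one contiguous allocation, each input bound to an offset, as
  // MallocContinuousMemFromMemPool does for the DeviceAddress list.
  std::vector<uint8_t> block(total_size);  // host stand-in for one device block
  size_t offset = 0;
  for (size_t i = 0; i < align_size_list.size(); ++i) {
    std::cout << "input " << i << ": offset " << offset << ", aligned size "
              << align_size_list[i] << "\n";
    offset += align_size_list[i];
  }
  std::cout << "contiguous block: " << block.size() << " bytes\n";
  return 0;
}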
@@ -166,52 +162,53 @@ void KernelRuntime::RunOpAssignCommunicationInput(const AnfNodePtr &node) const
   }
 }

-void KernelRuntime::RunOpGetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size,
-                                                    std::vector<size_t> *align_size_list,
-                                                    std::vector<DeviceAddressPtr> *device_address_list) const {
+void KernelRuntime::GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size,
+                                               DeviceAddressPtrList *address_list,
+                                               std::vector<size_t> *align_size_list) const {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(total_size);
   MS_EXCEPTION_IF_NULL(align_size_list);
-  MS_EXCEPTION_IF_NULL(device_address_list);
-  auto runtime_info = node->user_data<session::OpRuntimeInfo>();
-  auto output_num = AnfAlgo::GetOutputTensorNum(node);
-  for (size_t i = 0; i < output_num; ++i) {
-    MS_EXCEPTION_IF_NULL(runtime_info);
+  MS_EXCEPTION_IF_NULL(address_list);
+
+  const auto kernel_mod = AnfAlgo::GetKernelMod(node);
+  MS_EXCEPTION_IF_NULL(kernel_mod);
+  const auto output_size_list = kernel_mod->GetOutputSizeList();
+  for (size_t i = 0; i < output_size_list.size(); ++i) {
     DeviceAddressPtr address = nullptr;
     if (AnfAlgo::OutputAddrExist(node, i)) {
       address = AnfAlgo::GetMutableOutputAddr(node, i);
     } else {
-      std::string output_format = runtime_info->output_format(i);
-      auto output_type = runtime_info->output_type(i);
-      address =
-        CreateDeviceAddress(nullptr, runtime_info->output_tensor_size(i), output_format, output_type, {node, i});
+      const std::string output_format = AnfAlgo::GetOutputFormat(node, i);
+      const auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
+      const auto tensor_size = AnfAlgo::GetOutputTensorMemSize(node, i);
+      address = CreateDeviceAddress(nullptr, tensor_size, output_format, output_type, {node, i});
+      AnfAlgo::SetOutputAddr(address, i, node.get());
     }
     MS_EXCEPTION_IF_NULL(address);
     auto align_size = MemoryManager::GetCommonAlignSize(address->size());
     *total_size += align_size;
     align_size_list->emplace_back(align_size);
-    device_address_list->emplace_back(address);
+    address_list->emplace_back(address);
   }
 }

-void KernelRuntime::RunOpAssignCommunicationOutput(const AnfNodePtr &node) const {
+void KernelRuntime::AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const {
   if (!AnfAlgo::IsCommunicationOp(node)) {
     return;
   }
-  MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
+
   size_t total_size = 0;
   std::vector<size_t> align_size_list;
-  std::vector<DeviceAddressPtr> device_address_list;
-  RunOpGetCommunicationOutputInfo(node, &total_size, &align_size_list, &device_address_list);
-
+  std::vector<DeviceAddressPtr> address_list;
+  GetCommunicationOutputInfo(node, &total_size, &address_list, &align_size_list);
   if (align_size_list.empty()) {
+    MS_LOG(WARNING) << "No output for " << node->fullname_with_scope();
     return;
   }

-  if (!mem_manager_->MallocContinuousMemFromMemPool(device_address_list, total_size, align_size_list)) {
+  if (!mem_manager_->MallocContinuousMemFromMemPool(address_list, total_size, align_size_list)) {
     MS_LOG(EXCEPTION) << "Allocate continuous memory failed, total_size:" << total_size;
   }
 }
@@ -318,8 +315,8 @@ void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &inpu
   mem_manager_->ResetDynamicMemory();

   for (const auto &node : graph.execution_order()) {
-    RunOpAssignCommunicationOutput(node);
-    RunOpAssignCommunicationInput(node);
+    AssignCommunicationOutputFromMemoryPool(node);
+    AssignCommunicationInputFromMemoryPool(node);
   }

   RunOpAssignInputMemory(input_tensors, graph);
@@ -688,62 +685,6 @@ void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &n
   AssignWorkSpaceMem(type, node);
 }

-void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
-  auto &kernels = graph.execution_order();
-  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
-    return;
-  }
-  auto kernel_events =
-    std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
-  auto &kernel_pre_run_events = kernel_events.first;
-  auto &kernel_post_run_events = kernel_events.second;
-  kernel_pre_run_events.resize(kernels.size());
-  kernel_post_run_events.resize(kernels.size());
-  for (size_t i = 0; i < kernels.size(); ++i) {
-    auto &kernel = kernels[i];
-    if (!AnfAlgo::IsCommunicationOp(kernel)) {
-      continue;
-    }
-    auto pre_event = CreateDeviceEvent();
-    auto post_event = CreateDeviceEvent();
-    MS_EXCEPTION_IF_NULL(pre_event);
-    MS_EXCEPTION_IF_NULL(post_event);
-    pre_event->set_wait_stream(communication_stream_);
-    pre_event->set_record_stream(stream_);
-    post_event->set_wait_stream(stream_);
-    post_event->set_record_stream(communication_stream_);
-    kernel_pre_run_events[i].emplace_back([pre_event]() {
-      pre_event->RecordEvent();
-      pre_event->WaitEvent();
-    });
-    kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });
-    bool found_nearest_child = false;
-    for (size_t j = i + 1; j < kernels.size(); ++j) {
-      auto &child = kernels[j];
-      MS_EXCEPTION_IF_NULL(child);
-      if (AnfAlgo::IsCommunicationOp(child)) {
-        continue;
-      }
-      auto input_size = child->inputs().size() - 1;
-      for (size_t k = 0; k < input_size; ++k) {
-        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
-        if (kernel_index.first == kernel) {
-          found_nearest_child = true;
-          break;
-        }
-      }
-      if (found_nearest_child) {
-        kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
-        break;
-      }
-    }
-    if (!found_nearest_child) {
-      kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
-    }
-  }
-  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
-}
-
 void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
@@ -1174,7 +1115,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
   }
 }

-bool KernelRuntime::use_mem_scheduler() {
+bool KernelRuntime::UseMemScheduler() {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER)) {
@@ -1185,6 +1126,62 @@ bool KernelRuntime::use_mem_scheduler() {
     (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode));
 }

+void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
+  auto &kernels = graph.execution_order();
+  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
+    return;
+  }
+  auto kernel_events =
+    std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
+  auto &kernel_pre_run_events = kernel_events.first;
+  auto &kernel_post_run_events = kernel_events.second;
+  kernel_pre_run_events.resize(kernels.size());
+  kernel_post_run_events.resize(kernels.size());
+  for (size_t i = 0; i < kernels.size(); ++i) {
+    auto &kernel = kernels[i];
+    if (!AnfAlgo::IsCommunicationOp(kernel)) {
+      continue;
+    }
+    auto pre_event = CreateDeviceEvent();
+    auto post_event = CreateDeviceEvent();
+    MS_EXCEPTION_IF_NULL(pre_event);
+    MS_EXCEPTION_IF_NULL(post_event);
+    pre_event->set_wait_stream(communication_stream_);
+    pre_event->set_record_stream(stream_);
+    post_event->set_wait_stream(stream_);
+    post_event->set_record_stream(communication_stream_);
+    kernel_pre_run_events[i].emplace_back([pre_event]() {
+      pre_event->RecordEvent();
+      pre_event->WaitEvent();
+    });
+    kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });
+    bool found_nearest_child = false;
+    for (size_t j = i + 1; j < kernels.size(); ++j) {
+      auto &child = kernels[j];
+      MS_EXCEPTION_IF_NULL(child);
+      if (AnfAlgo::IsCommunicationOp(child)) {
+        continue;
+      }
+      auto input_size = child->inputs().size() - 1;
+      for (size_t k = 0; k < input_size; ++k) {
+        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
+        if (kernel_index.first == kernel) {
+          found_nearest_child = true;
+          break;
+        }
+      }
+      if (found_nearest_child) {
+        kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
+        break;
+      }
+    }
+    if (!found_nearest_child) {
+      kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
+    }
+  }
+  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
+}
+
 void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs,
                                            const std::shared_ptr<MemScheduler> &mem_scheduler) {
   MS_EXCEPTION_IF_NULL(cnode);
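Note (illustration, not part of the patch): GenKernelEvents is moved below UseMemScheduler but is otherwise unchanged. It fences every communication kernel with a pair of device events: pre_event makes the communication stream wait until the compute stream reaches the launch point, and post_event makes the nearest compute-stream consumer of the op's output (or, if no consumer is found, the op's own post-run hook) wait for the collective to finish. The standalone sketch below replays that handshake with toy Event and hook types; the real types are the DeviceEvent returned by CreateDeviceEvent() and the stream_ / communication_stream_ handles.

#include <functional>
#include <iostream>
#include <vector>

// Toy event: RecordEvent marks a point on its record stream, and WaitEvent
// blocks the waiting stream until that point has been reached.
struct Event {
  const char *name;
  void RecordEvent() const { std::cout << "record " << name << "\n"; }
  void WaitEvent() const { std::cout << "wait   " << name << "\n"; }
};

int main() {
  const Event pre{"pre_event  (comm stream waits for compute stream)"};
  const Event post{"post_event (compute stream waits for comm stream)"};

  // Per-kernel hook lists, as in kernel_pre_run_events / kernel_post_run_events.
  std::vector<std::function<void()>> comm_pre, comm_post, consumer_pre;

  // Before the communication kernel: fence it behind prior compute work.
  comm_pre.emplace_back([&pre]() {
    pre.RecordEvent();  // recorded on the compute stream
    pre.WaitEvent();    // waited on by the communication stream
  });
  // After it: record completion on the communication stream...
  comm_post.emplace_back([&post]() { post.RecordEvent(); });
  // ...and block the nearest compute kernel that consumes its output.
  consumer_pre.emplace_back([&post]() { post.WaitEvent(); });

  for (const auto &hook : comm_pre) hook();
  std::cout << "launch AllReduce on communication stream\n";
  for (const auto &hook : comm_post) hook();
  for (const auto &hook : consumer_pre) hook();
  std::cout << "launch consumer kernel on compute stream\n";
  return 0;
}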
@@ -1416,6 +1413,16 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
   }
 }

+void KernelRuntime::AssignCommunicationMem(const session::KernelGraph &graph) {
+  for (const auto &kernel : graph.execution_order()) {
+    if (!AnfAlgo::IsCommunicationOp(kernel)) {
+      continue;
+    }
+    AssignCommunicationInputFromMemoryPool(kernel);
+    AssignCommunicationOutputFromMemoryPool(kernel);
+  }
+}
+
 bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                                  const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock) {
   MS_EXCEPTION_IF_NULL(kernel);
@@ -1465,7 +1472,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
-  if (use_mem_scheduler()) {
+  if (UseMemScheduler()) {
     mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
     MS_EXCEPTION_IF_NULL(mem_scheduler);
     mem_scheduler->SetMemHandler(mem_manager_);
@@ -1533,7 +1540,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
 void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!use_mem_scheduler()) {
+  if (!UseMemScheduler()) {
     return;
   }
   auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
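Note (illustration, not part of the patch): this is the core of the change. When the memory scheduler (memory swap) is enabled, AssignMemory now calls AssignCommunicationMem, so the inputs and outputs of every communication op are carved out of the memory pool once, up front, and keep stable contiguous device addresses, while the remaining tensors stay under MemScheduler control and may be swapped between host and device. The standalone sketch below illustrates that split; the "pinned" and "swappable" wording is the sketch's, not a MindSpore API.

#include <iostream>
#include <string>
#include <vector>

struct Kernel {
  std::string name;
  bool is_communication_op;
};

int main() {
  const bool use_mem_scheduler = true;  // what UseMemScheduler() would report
  const std::vector<Kernel> execution_order = {
    {"Conv2D", false}, {"AllReduce", true}, {"ReLU", false}};

  for (const auto &kernel : execution_order) {
    if (!use_mem_scheduler) {
      std::cout << kernel.name << ": static/dynamic assignment as before\n";
      continue;
    }
    // Mirrors AssignCommunicationMem: only communication ops get fixed,
    // contiguous pool memory before the graph runs; everything else is left
    // to the scheduler, which may swap it out and back in.
    std::cout << kernel.name << ": "
              << (kernel.is_communication_op ? "pinned contiguous pool memory"
                                             : "swappable, managed by MemScheduler")
              << "\n";
  }
  return 0;
}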
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h
index b13df3f872..4834d4ecae 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h
@@ -57,8 +57,8 @@ class KernelRuntime {
   virtual void AssignMemory(const session::KernelGraph &graph);
   void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph,
                          const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
-  void RunOpAssignCommunicationOutput(const AnfNodePtr &node) const;
-  void RunOpAssignCommunicationInput(const AnfNodePtr &node) const;
+  void AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const;
+  void AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const;
   void RunOpClearMemory(const session::KernelGraph &graph) const;
   void RunOpMallocPre(const session::KernelGraph &graph, const std::vector<tensor::TensorPtr> &input_tensors);
 #ifdef ENABLE_DEBUGGER
@@ -94,7 +94,7 @@ class KernelRuntime {
   virtual void ReleaseDeviceRes() {}
   void set_device_id(uint32_t device_id) { device_id_ = device_id; }
   uint32_t device_id() { return device_id_; }
-  static bool use_mem_scheduler();
+  static bool UseMemScheduler();

 #ifdef ENABLE_DEBUGGER
   // set debugger
@@ -152,6 +152,8 @@ class KernelRuntime {
   void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph);
   void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph,
                              const AnfNodePtr &kernel, bool mock);
+
+  void AssignCommunicationMem(const session::KernelGraph &graph);
   void AssignStaticMemoryOutput(const session::KernelGraph &graph);
   bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);
   void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index) const;
@@ -171,11 +173,10 @@ class KernelRuntime {
   void CheckIfSupportPSEmbeddingCache(const session::KernelGraph &graph);
   void CheckSparsePSEmbeddingCache(const CNodePtr &node);
 #endif
-  void RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
-                                      std::vector<DeviceAddressPtr> *address_list,
-                                      std::vector<size_t> *align_size_list) const;
-  void RunOpGetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, std::vector<size_t> *align_size_list,
-                                       std::vector<DeviceAddressPtr> *device_address_list) const;
+  void GetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
+                                 std::vector<size_t> *align_size_list) const;
+  void GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
+                                  std::vector<size_t> *align_size_list) const;

  protected:
   uint32_t device_id_{0};