From 613aa757c3c19a152ffae348eb0f4ee8857a5db0 Mon Sep 17 00:00:00 2001 From: hwjiaorui Date: Mon, 15 Nov 2021 16:37:09 +0800 Subject: [PATCH] modify events index --- .../device/ascend/ascend_kernel_runtime.cc | 24 +++++++------- .../device/ascend/ascend_kernel_runtime.h | 2 +- .../ccsrc/runtime/device/kernel_runtime.cc | 31 +++++++++---------- .../ccsrc/runtime/device/kernel_runtime.h | 7 +++-- 4 files changed, 31 insertions(+), 33 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 2c5ee91ea5..f7904e789e 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -779,12 +779,10 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { } std::vector last_stream_nodes; SetKernelModStream(kernels, &last_stream_nodes); - auto kernel_events = - std::pair>>, std::vector>>>(); + auto kernel_events = std::pair>>, + std::map>>>(); auto &kernel_pre_run_events = kernel_events.first; auto &kernel_post_run_events = kernel_events.second; - kernel_pre_run_events.resize(kernels.size()); - kernel_post_run_events.resize(kernels.size()); auto stream_num = stream_id_map_.size(); std::vector> kernel_hit(kernels.size(), std::vector(stream_num, false)); for (size_t i = 0; i < kernels.size(); ++i) { @@ -815,8 +813,8 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { auto event = CreateDeviceEvent(); event->set_wait_stream(wait_stream); event->set_record_stream(record_stream); - kernel_post_run_events[IntToSize(k)].emplace_back([event]() { event->RecordEvent(); }); - kernel_pre_run_events[i].emplace_back([event]() { event->WaitEvent(); }); + kernel_post_run_events[pre_cnode].emplace_back([event]() { event->RecordEvent(); }); + kernel_pre_run_events[kernel].emplace_back([event]() { event->WaitEvent(); }); } } } @@ -824,17 +822,17 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { auto pre_event = CreateDeviceEvent(); pre_event->set_wait_stream(wait_stream); pre_event->set_record_stream(stream_); - kernel_pre_run_events[i].emplace_back([pre_event]() { pre_event->RecordEvent(); }); - kernel_pre_run_events[i].emplace_back([pre_event]() { pre_event->WaitEvent(); }); + kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->RecordEvent(); }); + kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->WaitEvent(); }); } } ProcessBoundaryEvent(kernels, &kernel_post_run_events, last_stream_nodes); graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events); } -void AscendKernelRuntime::ProcessBoundaryEvent(const std::vector &kernels, - std::vector>> *kernel_run_events, - const std::vector &last_stream_nodes) { +void AscendKernelRuntime::ProcessBoundaryEvent( + const std::vector &kernels, std::map>> *kernel_run_events, + const std::vector &last_stream_nodes) { for (auto &i : last_stream_nodes) { if (i >= kernels.size()) { MS_LOG(ERROR) << "Node index exceed kernel size."; @@ -865,8 +863,8 @@ void AscendKernelRuntime::ProcessBoundaryEvent(const std::vector &kern auto record_stream = stream_id_map_[id]; post_event->set_wait_stream(stream_); post_event->set_record_stream(record_stream); - (*kernel_run_events)[i].emplace_back([post_event]() { post_event->RecordEvent(); }); - (*kernel_run_events)[i].emplace_back([post_event]() { post_event->WaitEvent(); }); + (*kernel_run_events)[kernel].emplace_back([post_event]() { post_event->RecordEvent(); }); + (*kernel_run_events)[kernel].emplace_back([post_event]() { post_event->WaitEvent(); }); } } } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index 6a3022c470..5fbec2618d 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -46,7 +46,7 @@ class AscendKernelRuntime : public KernelRuntime { void GenKernelEvents(const session::KernelGraph &graph) override; void SetKernelModStream(const std::vector &kernels, std::vector *last_stream_nodes); void ProcessBoundaryEvent(const std::vector &kernels, - std::vector>> *kernel_run_events, + std::map>> *kernel_run_events, const std::vector &last_stream_nodes); bool GenDynamicKernel(const session::KernelGraph &graph) override; bool RunDynamicKernelAsync(const session::KernelGraph &graph) override; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index e48e28804c..9fe8cc8897 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -1138,12 +1138,10 @@ void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) { return; } - auto kernel_events = - std::pair>>, std::vector>>>(); + auto kernel_events = std::pair>>, + std::map>>>(); auto &kernel_pre_run_events = kernel_events.first; auto &kernel_post_run_events = kernel_events.second; - kernel_pre_run_events.resize(kernels.size()); - kernel_post_run_events.resize(kernels.size()); for (size_t i = 0; i < kernels.size(); ++i) { auto &kernel = kernels[i]; if (!AnfAlgo::IsCommunicationOp(kernel)) { @@ -1157,11 +1155,11 @@ void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { pre_event->set_record_stream(stream_); post_event->set_wait_stream(stream_); post_event->set_record_stream(communication_stream_); - kernel_pre_run_events[i].emplace_back([pre_event]() { + kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->RecordEvent(); pre_event->WaitEvent(); }); - kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); }); + kernel_post_run_events[kernel].emplace_back([post_event]() { post_event->RecordEvent(); }); bool found_nearest_child = false; for (size_t j = i + 1; j < kernels.size(); ++j) { auto &child = kernels[j]; @@ -1178,12 +1176,12 @@ void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { } } if (found_nearest_child) { - kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); }); + kernel_pre_run_events[child].emplace_back([post_event]() { post_event->WaitEvent(); }); break; } } if (!found_nearest_child) { - kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); }); + kernel_post_run_events[kernel].emplace_back([post_event]() { post_event->WaitEvent(); }); } } graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events); @@ -1243,12 +1241,13 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList } } -void KernelRuntime::LaunchKernelEvent(const std::vector>> &kernel_events, - size_t index) const { - if (index >= kernel_events.size()) { +void KernelRuntime::LaunchKernelEvent(const std::map>> &kernel_events, + const AnfNodePtr &node) const { + if (kernel_events.find(node) == kernel_events.end()) { return; } - for (auto &event : kernel_events[index]) { + + for (auto &event : kernel_events.at(node)) { event(); } } @@ -1511,15 +1510,15 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock MS_LOG(EXCEPTION) << "The size of dynamic kernels " << dynamic_kernel_list.size() << " should be equal to the size of kernels " << kernels.size(); } - std::vector>> kernel_pre_run_events; - std::vector>> kernel_post_run_events; + std::map>> kernel_pre_run_events; + std::map>> kernel_post_run_events; auto events_iter = graph_kernel_events_map_.find(graph.graph_id()); if (events_iter != graph_kernel_events_map_.end()) { kernel_pre_run_events = events_iter->second.first; kernel_post_run_events = events_iter->second.second; } for (size_t i = 0; i < kernels.size(); ++i) { - LaunchKernelEvent(kernel_pre_run_events, i); + LaunchKernelEvent(kernel_pre_run_events, kernels[i]); if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr && dynamic_kernel_list[i]->is_dynamic_shape()) { dynamic_kernel_list[i]->InferShape(); @@ -1553,7 +1552,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock KernelLaunchProfiling(kernel->fullname_with_scope()); DebugStreamSync(kernel); } - LaunchKernelEvent(kernel_post_run_events, i); + LaunchKernelEvent(kernel_post_run_events, kernels[i]); } return true; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index bba52fe7a5..e2714c984b 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -161,7 +161,8 @@ class KernelRuntime { void AssignCommunicationMem(const session::KernelGraph &graph); bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false); - void LaunchKernelEvent(const std::vector>> &run_events, size_t index) const; + void LaunchKernelEvent(const std::map>> &run_events, + const AnfNodePtr &node) const; void DebugStreamSync(const CNodePtr &kernel); static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs, const std::shared_ptr &mem_schedule = nullptr); @@ -195,8 +196,8 @@ class KernelRuntime { void *communication_stream_{nullptr}; std::shared_ptr mem_manager_{nullptr}; std::map> graph_dynamic_kernel_map_; - std::map>>, std::vector>>>> + std::map>>, + std::map>>>> graph_kernel_events_map_; MemSchedulerManager mem_scheduler_manager_; };