| @@ -779,12 +779,10 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { | |||
| } | |||
| std::vector<size_t> last_stream_nodes; | |||
| SetKernelModStream(kernels, &last_stream_nodes); | |||
| auto kernel_events = | |||
| std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>(); | |||
| auto kernel_events = std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>, | |||
| std::map<AnfNodePtr, std::vector<std::function<void()>>>>(); | |||
| auto &kernel_pre_run_events = kernel_events.first; | |||
| auto &kernel_post_run_events = kernel_events.second; | |||
| kernel_pre_run_events.resize(kernels.size()); | |||
| kernel_post_run_events.resize(kernels.size()); | |||
| auto stream_num = stream_id_map_.size(); | |||
| std::vector<std::vector<bool>> kernel_hit(kernels.size(), std::vector<bool>(stream_num, false)); | |||
| for (size_t i = 0; i < kernels.size(); ++i) { | |||
| @@ -815,8 +813,8 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { | |||
| auto event = CreateDeviceEvent(); | |||
| event->set_wait_stream(wait_stream); | |||
| event->set_record_stream(record_stream); | |||
| kernel_post_run_events[IntToSize(k)].emplace_back([event]() { event->RecordEvent(); }); | |||
| kernel_pre_run_events[i].emplace_back([event]() { event->WaitEvent(); }); | |||
| kernel_post_run_events[pre_cnode].emplace_back([event]() { event->RecordEvent(); }); | |||
| kernel_pre_run_events[kernel].emplace_back([event]() { event->WaitEvent(); }); | |||
| } | |||
| } | |||
| } | |||
| @@ -824,17 +822,17 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { | |||
| auto pre_event = CreateDeviceEvent(); | |||
| pre_event->set_wait_stream(wait_stream); | |||
| pre_event->set_record_stream(stream_); | |||
| kernel_pre_run_events[i].emplace_back([pre_event]() { pre_event->RecordEvent(); }); | |||
| kernel_pre_run_events[i].emplace_back([pre_event]() { pre_event->WaitEvent(); }); | |||
| kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->RecordEvent(); }); | |||
| kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->WaitEvent(); }); | |||
| } | |||
| } | |||
| ProcessBoundaryEvent(kernels, &kernel_post_run_events, last_stream_nodes); | |||
| graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events); | |||
| } | |||
| void AscendKernelRuntime::ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels, | |||
| std::vector<std::vector<std::function<void()>>> *kernel_run_events, | |||
| const std::vector<size_t> &last_stream_nodes) { | |||
| void AscendKernelRuntime::ProcessBoundaryEvent( | |||
| const std::vector<CNodePtr> &kernels, std::map<AnfNodePtr, std::vector<std::function<void()>>> *kernel_run_events, | |||
| const std::vector<size_t> &last_stream_nodes) { | |||
| for (auto &i : last_stream_nodes) { | |||
| if (i >= kernels.size()) { | |||
| MS_LOG(ERROR) << "Node index exceed kernel size."; | |||
| @@ -865,8 +863,8 @@ void AscendKernelRuntime::ProcessBoundaryEvent(const std::vector<CNodePtr> &kern | |||
| auto record_stream = stream_id_map_[id]; | |||
| post_event->set_wait_stream(stream_); | |||
| post_event->set_record_stream(record_stream); | |||
| (*kernel_run_events)[i].emplace_back([post_event]() { post_event->RecordEvent(); }); | |||
| (*kernel_run_events)[i].emplace_back([post_event]() { post_event->WaitEvent(); }); | |||
| (*kernel_run_events)[kernel].emplace_back([post_event]() { post_event->RecordEvent(); }); | |||
| (*kernel_run_events)[kernel].emplace_back([post_event]() { post_event->WaitEvent(); }); | |||
| } | |||
| } | |||
| } | |||
| @@ -46,7 +46,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||
| void GenKernelEvents(const session::KernelGraph &graph) override; | |||
| void SetKernelModStream(const std::vector<CNodePtr> &kernels, std::vector<size_t> *last_stream_nodes); | |||
| void ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels, | |||
| std::vector<std::vector<std::function<void()>>> *kernel_run_events, | |||
| std::map<AnfNodePtr, std::vector<std::function<void()>>> *kernel_run_events, | |||
| const std::vector<size_t> &last_stream_nodes); | |||
| bool GenDynamicKernel(const session::KernelGraph &graph) override; | |||
| bool RunDynamicKernelAsync(const session::KernelGraph &graph) override; | |||
| @@ -1138,12 +1138,10 @@ void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { | |||
| if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) { | |||
| return; | |||
| } | |||
| auto kernel_events = | |||
| std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>(); | |||
| auto kernel_events = std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>, | |||
| std::map<AnfNodePtr, std::vector<std::function<void()>>>>(); | |||
| auto &kernel_pre_run_events = kernel_events.first; | |||
| auto &kernel_post_run_events = kernel_events.second; | |||
| kernel_pre_run_events.resize(kernels.size()); | |||
| kernel_post_run_events.resize(kernels.size()); | |||
| for (size_t i = 0; i < kernels.size(); ++i) { | |||
| auto &kernel = kernels[i]; | |||
| if (!AnfAlgo::IsCommunicationOp(kernel)) { | |||
| @@ -1157,11 +1155,11 @@ void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { | |||
| pre_event->set_record_stream(stream_); | |||
| post_event->set_wait_stream(stream_); | |||
| post_event->set_record_stream(communication_stream_); | |||
| kernel_pre_run_events[i].emplace_back([pre_event]() { | |||
| kernel_pre_run_events[kernel].emplace_back([pre_event]() { | |||
| pre_event->RecordEvent(); | |||
| pre_event->WaitEvent(); | |||
| }); | |||
| kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); }); | |||
| kernel_post_run_events[kernel].emplace_back([post_event]() { post_event->RecordEvent(); }); | |||
| bool found_nearest_child = false; | |||
| for (size_t j = i + 1; j < kernels.size(); ++j) { | |||
| auto &child = kernels[j]; | |||
| @@ -1178,12 +1176,12 @@ void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { | |||
| } | |||
| } | |||
| if (found_nearest_child) { | |||
| kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); }); | |||
| kernel_pre_run_events[child].emplace_back([post_event]() { post_event->WaitEvent(); }); | |||
| break; | |||
| } | |||
| } | |||
| if (!found_nearest_child) { | |||
| kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); }); | |||
| kernel_post_run_events[kernel].emplace_back([post_event]() { post_event->WaitEvent(); }); | |||
| } | |||
| } | |||
| graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events); | |||
| @@ -1243,12 +1241,13 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList | |||
| } | |||
| } | |||
| void KernelRuntime::LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &kernel_events, | |||
| size_t index) const { | |||
| if (index >= kernel_events.size()) { | |||
| void KernelRuntime::LaunchKernelEvent(const std::map<AnfNodePtr, std::vector<std::function<void()>>> &kernel_events, | |||
| const AnfNodePtr &node) const { | |||
| if (kernel_events.find(node) == kernel_events.end()) { | |||
| return; | |||
| } | |||
| for (auto &event : kernel_events[index]) { | |||
| for (auto &event : kernel_events.at(node)) { | |||
| event(); | |||
| } | |||
| } | |||
| @@ -1511,15 +1510,15 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock | |||
| MS_LOG(EXCEPTION) << "The size of dynamic kernels " << dynamic_kernel_list.size() | |||
| << " should be equal to the size of kernels " << kernels.size(); | |||
| } | |||
| std::vector<std::vector<std::function<void()>>> kernel_pre_run_events; | |||
| std::vector<std::vector<std::function<void()>>> kernel_post_run_events; | |||
| std::map<AnfNodePtr, std::vector<std::function<void()>>> kernel_pre_run_events; | |||
| std::map<AnfNodePtr, std::vector<std::function<void()>>> kernel_post_run_events; | |||
| auto events_iter = graph_kernel_events_map_.find(graph.graph_id()); | |||
| if (events_iter != graph_kernel_events_map_.end()) { | |||
| kernel_pre_run_events = events_iter->second.first; | |||
| kernel_post_run_events = events_iter->second.second; | |||
| } | |||
| for (size_t i = 0; i < kernels.size(); ++i) { | |||
| LaunchKernelEvent(kernel_pre_run_events, i); | |||
| LaunchKernelEvent(kernel_pre_run_events, kernels[i]); | |||
| if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr && | |||
| dynamic_kernel_list[i]->is_dynamic_shape()) { | |||
| dynamic_kernel_list[i]->InferShape(); | |||
| @@ -1553,7 +1552,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock | |||
| KernelLaunchProfiling(kernel->fullname_with_scope()); | |||
| DebugStreamSync(kernel); | |||
| } | |||
| LaunchKernelEvent(kernel_post_run_events, i); | |||
| LaunchKernelEvent(kernel_post_run_events, kernels[i]); | |||
| } | |||
| return true; | |||
| @@ -161,7 +161,8 @@ class KernelRuntime { | |||
| void AssignCommunicationMem(const session::KernelGraph &graph); | |||
| bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false); | |||
| void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index) const; | |||
| void LaunchKernelEvent(const std::map<AnfNodePtr, std::vector<std::function<void()>>> &run_events, | |||
| const AnfNodePtr &node) const; | |||
| void DebugStreamSync(const CNodePtr &kernel); | |||
| static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs, | |||
| const std::shared_ptr<MemScheduler> &mem_schedule = nullptr); | |||
| @@ -195,8 +196,8 @@ class KernelRuntime { | |||
| void *communication_stream_{nullptr}; | |||
| std::shared_ptr<MemoryManager> mem_manager_{nullptr}; | |||
| std::map<uint32_t, std::vector<DynamicKernelPtr>> graph_dynamic_kernel_map_; | |||
| std::map<uint32_t, | |||
| std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>> | |||
| std::map<uint32_t, std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>, | |||
| std::map<AnfNodePtr, std::vector<std::function<void()>>>>> | |||
| graph_kernel_events_map_; | |||
| MemSchedulerManager mem_scheduler_manager_; | |||
| }; | |||