From: @kisnwang Reviewed-by: Signed-off-by:pull/13534/MERGE
| @@ -25,8 +25,8 @@ namespace kernel { | |||
| bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||
| MS_LOG(INFO) << "HcclAllReduce launch"; | |||
| if (inputs.size() != 1 || outputs.size() != 1) { | |||
| MS_LOG(ERROR) << "AllReduce input output size must be 1"; | |||
| if (inputs.empty() || outputs.empty()) { | |||
| MS_LOG(ERROR) << "Invalid AllReduce input output size(" << inputs.size() << ", " << outputs.size() << ")."; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(inputs[0]); | |||
| @@ -475,7 +475,8 @@ bool TaskEmitAction(const ResourcePtr &res) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| std::string backend = MsContext::GetInstance()->backend_policy(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (func_graph->ContainMultiTarget()) { | |||
| auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||
| if (func_graph->ContainMultiTarget() || !task_sink) { | |||
| bc_ptr->set_is_multi_graph_sink(false); | |||
| context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false); | |||
| context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false); | |||
| @@ -35,6 +35,7 @@ | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "runtime/device/ascend/profiling/profiling_utils.h" | |||
| #include "runtime/device/ascend/ascend_memory_manager.h" | |||
| #include "runtime/device/ascend/ascend_event.h" | |||
| #include "debug/data_dump/dump_json_parser.h" | |||
| #include "toolchain/adx_datadump_server.h" | |||
| #include "utils/trace_base.h" | |||
| @@ -154,7 +155,7 @@ void AscendKernelRuntime::ClearGraphModelMap() { | |||
| DumpJsonParser::GetInstance().PrintUnusedKernel(); | |||
| graph_dynamic_kernel_map_.clear(); | |||
| graph_kernel_events_map_.clear(); | |||
| for (auto &iter : graph_model_map_) { | |||
| MS_LOG(INFO) << "Ge UnloadModel " << iter.first; | |||
| auto ret = ModelRunner::Instance().UnloadModel(iter.first); | |||
| @@ -186,7 +187,10 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std | |||
| MS_LOG(DEBUG) << "Start Clear graph:" << graph_id << " dynamic kernel"; | |||
| graph_dynamic_kernel_map_.erase(dynamic_kernel_iter); | |||
| } | |||
| auto events_iter = graph_kernel_events_map_.find(graph_id); | |||
| if (events_iter != graph_kernel_events_map_.end()) { | |||
| graph_kernel_events_map_.erase(events_iter); | |||
| } | |||
| MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; | |||
| if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) { | |||
| MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id; | |||
| @@ -340,9 +344,9 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size | |||
| bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { | |||
| if (!is_task_sink) { | |||
| GenKernelEvents(graph); | |||
| return true; | |||
| } | |||
| // Do HcomExecutorInitialize | |||
| if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) { | |||
| MS_LOG(ERROR) << "Init Hccl Executor Failed"; | |||
| @@ -357,6 +361,58 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { | |||
| return true; | |||
| } | |||
| void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto &kernels = graph->execution_order(); | |||
| if (kernels.empty()) { | |||
| return; | |||
| } | |||
| auto kernel_events = | |||
| std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>(); | |||
| auto &kernel_pre_run_events = kernel_events.first; | |||
| auto &kernel_post_run_events = kernel_events.second; | |||
| kernel_pre_run_events.resize(kernels.size()); | |||
| kernel_post_run_events.resize(kernels.size()); | |||
| for (size_t i = 0; i < kernels.size(); ++i) { | |||
| auto &kernel = kernels[i]; | |||
| if (!AnfAlgo::IsCommunicationOp(kernel)) { | |||
| continue; | |||
| } | |||
| auto pre_event = std::make_shared<AscendEvent>(); | |||
| auto post_event = std::make_shared<AscendEvent>(); | |||
| pre_event->set_wait_stream(communication_stream_); | |||
| pre_event->set_record_stream(stream_); | |||
| post_event->set_wait_stream(stream_); | |||
| post_event->set_record_stream(communication_stream_); | |||
| kernel_pre_run_events[i].emplace_back([pre_event]() { | |||
| pre_event->RecordEvent(); | |||
| pre_event->WaitEvent(); | |||
| }); | |||
| kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); }); | |||
| bool found_nearest_child = false; | |||
| for (size_t j = i + 1; j < kernels.size(); ++j) { | |||
| auto &child = kernels[j]; | |||
| MS_EXCEPTION_IF_NULL(child); | |||
| auto input_size = child->inputs().size() - 1; | |||
| for (size_t k = 0; k < input_size; ++k) { | |||
| auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0); | |||
| if (kernel_index.first == kernel) { | |||
| found_nearest_child = true; | |||
| break; | |||
| } | |||
| } | |||
| if (found_nearest_child) { | |||
| kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); }); | |||
| break; | |||
| } | |||
| } | |||
| if (!found_nearest_child) { | |||
| kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); }); | |||
| } | |||
| } | |||
| graph_kernel_events_map_[graph->graph_id()] = std::move(kernel_events); | |||
| } | |||
| bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "GenDynamicKernel start"; | |||
| @@ -374,7 +430,7 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) { | |||
| dynamic_kernel->Initialize(); | |||
| dynamic_kernels.emplace_back(dynamic_kernel); | |||
| } | |||
| graph_dynamic_kernel_map_[graph->graph_id()] = dynamic_kernels; | |||
| graph_dynamic_kernel_map_[graph->graph_id()] = std::move(dynamic_kernels); | |||
| MS_LOG(INFO) << "GenDynamicKernel end"; | |||
| return true; | |||
| } | |||
| @@ -852,8 +908,9 @@ bool AscendKernelRuntime::HcclInit() { | |||
| MS_LOG(ERROR) << "Hcom init failed."; | |||
| return false; | |||
| } | |||
| if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| MS_LOG(INFO) << "PyNative hccl init"; | |||
| auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||
| if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode || !task_sink) { | |||
| MS_LOG(INFO) << "Hccl comm init."; | |||
| return kernel::HcclContext::GetInstance().InitHccl(); | |||
| } | |||
| return true; | |||
| @@ -67,6 +67,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||
| bool KernelMemNotReuse(const AnfNodePtr &node) override; | |||
| void KernelLaunchProfiling(const std::string &kernel_name) override; | |||
| void GenKernelEvents(const session::KernelGraph *graph) override; | |||
| private: | |||
| bool InitDevice(); | |||
| @@ -31,6 +31,7 @@ | |||
| #include "utils/shape_utils.h" | |||
| #include "utils/utils.h" | |||
| #include "frontend/parallel/context.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| @@ -920,6 +921,16 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList | |||
| } | |||
| } | |||
| void KernelRuntime::LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &kernel_events, | |||
| size_t index) { | |||
| if (index >= kernel_events.size()) { | |||
| return; | |||
| } | |||
| for (auto &event : kernel_events[index]) { | |||
| event(); | |||
| } | |||
| } | |||
| bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { | |||
| const auto &kernels = graph.execution_order(); | |||
| std::vector<DynamicKernelPtr> dynamic_kernel_list; | |||
| @@ -931,12 +942,21 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { | |||
| MS_LOG(EXCEPTION) << "The size of dynamic kernels " << dynamic_kernel_list.size() | |||
| << " should be equal to the size of kernels " << kernels.size(); | |||
| } | |||
| std::vector<std::vector<std::function<void()>>> kernel_pre_run_events; | |||
| std::vector<std::vector<std::function<void()>>> kernel_post_run_events; | |||
| auto events_iter = graph_kernel_events_map_.find(graph.graph_id()); | |||
| if (events_iter != graph_kernel_events_map_.end()) { | |||
| kernel_pre_run_events = events_iter->second.first; | |||
| kernel_post_run_events = events_iter->second.second; | |||
| } | |||
| for (size_t i = 0; i < kernels.size(); ++i) { | |||
| LaunchKernelEvent(kernel_pre_run_events, i); | |||
| if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr && | |||
| dynamic_kernel_list[i]->is_dynamic_shape()) { | |||
| dynamic_kernel_list[i]->InferShape(); | |||
| dynamic_kernel_list[i]->UpdateArgs(); | |||
| dynamic_kernel_list[i]->Execute(); | |||
| if (!SyncStream()) { | |||
| MS_LOG(ERROR) << "SyncStream failed"; | |||
| return false; | |||
| @@ -958,20 +978,23 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { | |||
| } | |||
| continue; | |||
| } | |||
| AddressPtrList kernel_inputs; | |||
| AddressPtrList kernel_workspaces; | |||
| AddressPtrList kernel_outputs; | |||
| GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); | |||
| auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | |||
| bool ret; | |||
| if (AnfAlgo::IsCommunicationOp(kernel)) { | |||
| ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, communication_stream_); | |||
| } else { | |||
| ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | |||
| } | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Launch kernel failed."; | |||
| return false; | |||
| } | |||
| KernelLaunchProfiling(kernels[i]->fullname_with_scope()); | |||
| } | |||
| LaunchKernelEvent(kernel_post_run_events, i); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <map> | |||
| #include <utility> | |||
| #include <unordered_set> | |||
| #include "runtime/device/device_address.h" | |||
| #include "ir/tensor.h" | |||
| @@ -132,10 +133,12 @@ class KernelRuntime { | |||
| void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node); | |||
| virtual void KernelLaunchProfiling(const std::string &kernel_name) {} | |||
| virtual void GenKernelEvents(const session::KernelGraph *graph) {} | |||
| private: | |||
| void AssignStaticMemoryOutput(const session::KernelGraph *graph); | |||
| bool LaunchKernelMod(const session::KernelGraph &graph); | |||
| void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index); | |||
| static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); | |||
| size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); | |||
| void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph); | |||
| @@ -160,6 +163,9 @@ class KernelRuntime { | |||
| void *communication_stream_{nullptr}; | |||
| std::shared_ptr<MemoryManager> mem_manager_{nullptr}; | |||
| std::map<uint32_t, std::vector<DynamicKernelPtr>> graph_dynamic_kernel_map_; | |||
| std::map<uint32_t, | |||
| std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>> | |||
| graph_kernel_events_map_; | |||
| std::vector<std::shared_ptr<char[]>> buffer_ptrs_ = {}; | |||
| }; | |||
| using KernelRuntimePtr = std::shared_ptr<KernelRuntime>; | |||