From dd4397283997f2fee7edf8d2166ea4f24be3a971 Mon Sep 17 00:00:00 2001
From: looop5
Date: Sat, 27 Mar 2021 15:40:38 +0800
Subject: [PATCH] infer processor from ms_context inside function CreateCNode

---
 .../graph_kernel/graph_kernel_helper.cc          | 16 +++++++++++++++-
 .../optimizer/graph_kernel/graph_kernel_helper.h |  1 +
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 3de37c8798..bc502647cd 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -33,6 +33,7 @@
 #include "pipeline/jit/parse/python_adapter.h"
 #include "pipeline/jit/action.h"
 #include "vm/segment_runner.h"
+#include "utils/ms_context.h"
 #if ENABLE_GPU
 #include "runtime/device/gpu/kernel_info_setter.h"
 #endif
@@ -796,6 +797,19 @@ std::vector GetReduceAxis(const AnfNodePtr &node) {
   return axis;
 }
 
+kernel::Processor GetProcessorFromContext() {  // Maps the context's MS_CTX_DEVICE_TARGET to a kernel::Processor.
+  kernel::Processor processor = kernel::Processor::UNKNOWN;  // Returned when the target is neither GPU nor Ascend.
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);  // Null guard on the context; the macro is defined outside this patch.
+  auto device_info = context_ptr->get_param(MS_CTX_DEVICE_TARGET);  // Value of the MS_CTX_DEVICE_TARGET parameter.
+  if (device_info == kGPUDevice) {
+    processor = kernel::Processor::CUDA;  // kGPUDevice -> CUDA.
+  } else if (device_info == kAscendDevice) {
+    processor = kernel::Processor::AICORE;  // kAscendDevice -> AICORE.
+  }
+  return processor;
+}
+
 CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info) {
   // Limitation: 1. Node's attributes should be set out of this function; 2. only one output. 
   MS_EXCEPTION_IF_NULL(out_info.type);
@@ -852,7 +866,7 @@ CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr &
   info_builder.SetInputsDeviceType(input_types);
   info_builder.SetOutputsFormat(output_formats);
   info_builder.SetOutputsDeviceType(output_types);
-  info_builder.SetProcessor(kernel::Processor::CUDA);
+  info_builder.SetProcessor(GetProcessorFromContext());  // Processor comes from the context, no longer fixed to CUDA.
   info_builder.SetKernelType(KernelType::AKG_KERNEL);
   info_builder.SetFusionType(kernel::FusionType::OPAQUE);
   auto selected_info = info_builder.Build();
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
index 0e4ead2a4d..3bb52e3853 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
@@ -87,6 +87,7 @@ TypePtr GetType(const AnfNodePtr &node);
 ShapeVector GetShape(const AnfNodePtr &node);
 ShapeVector GetDeviceShape(const AnfNodePtr &node);
 std::vector GetReduceAxis(const AnfNodePtr &node);
+kernel::Processor GetProcessorFromContext();  // CUDA for kGPUDevice, AICORE for kAscendDevice, otherwise UNKNOWN.
 CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
 void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);