From d608ffdaf20e59e7df2f261fc372d5dfd000a75e Mon Sep 17 00:00:00 2001 From: dayschan Date: Mon, 14 Feb 2022 16:07:39 +0800 Subject: [PATCH] explicitly check if graph_kernel is enabled in pynative mode. --- .../runtime/graph_scheduler/graph_compiler.cc | 5 +-- .../ccsrc/utils/context/graph_kernel_flags.cc | 37 ++++++++++++------- .../ccsrc/utils/context/graph_kernel_flags.h | 3 ++ 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc index e385856e68..c7fb7b87b7 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc @@ -377,9 +377,8 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod GraphId graph_id; if (run_in_pynative) { MS_EXCEPTION_IF_NULL(session_); - // Graphkernel not support pynative mode now, so when users open graphkernel in pynative mode - // should print a warning log to reminder users by using GetInstance func. - (void)graphkernel::GraphKernelFlags::GetInstance(); + // Graphkernel does not support pynative mode now, print a warning here. + graphkernel::GraphKernelFlags::GetInstance().CheckSupport(); session_->InitAllBucket(graph, device_context); graph_id = graph->graph_id(); } else { diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc index cfcbdfe1ce..e721202ccf 100644 --- a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc +++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc @@ -171,6 +171,28 @@ std::pair GraphKernelFlags::GetGraphKernelContext() { return std::make_pair(flags, enable_context); } +void GraphKernelFlags::CheckSupport() const { +#ifndef MSLITE_ENABLE_GRAPH_KERNEL + if (IsEnableGraphKernel()) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + if (context->get_param(MS_CTX_EXECUTION_MODE) != kGraphMode) { + MS_LOG(WARNING) << "GraphKernel only support GRAPH_MODE."; + const_cast(this)->opt_level = OptLevel_0; + return; + } +#ifndef USE_LLVM + auto is_cpu = (context->get_param(MS_CTX_DEVICE_TARGET) == kCPUDevice); + if (is_cpu) { + MS_LOG(WARNING) << "GraphKernel is not usable without LLVM on cpu platform."; + const_cast(this)->opt_level = OptLevel_0; + return; + } +#endif + } +#endif +} + void GraphKernelFlags::Refresh() { auto flag_map = ParseFlags(flags_cache_); RegisterFlags(&flag_map); @@ -179,28 +201,15 @@ void GraphKernelFlags::Refresh() { } #ifndef MSLITE_ENABLE_GRAPH_KERNEL if (IsEnableGraphKernel()) { + CheckSupport(); auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - if (context->get_param(MS_CTX_EXECUTION_MODE) != kGraphMode) { - MS_LOG(WARNING) << "GraphKernel only support GRAPH_MODE"; - opt_level = OptLevel_0; - } - - // check whether on ascend open graphkernel, if open, may cause error, reminder - // users to close this feature. auto is_ascend = (context->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice); if (is_ascend) { MS_LOG(WARNING) << "GraphKernel on Ascend is experimental, please disable it if you meet some compiling or running error. For " "more details, please refer to 'mindspore.context' at https://www.mindspore.cn."; } -#ifndef USE_LLVM - auto is_cpu = (context->get_param(MS_CTX_DEVICE_TARGET) == kCPUDevice); - if (is_cpu) { - MS_LOG(WARNING) << "GraphKernel is not usable without LLVM on cpu platform"; - opt_level = OptLevel_0; - } -#endif } #endif // If enable graphkernel, Dump flags so that people can check the setting. diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.h b/mindspore/ccsrc/utils/context/graph_kernel_flags.h index 53418349b7..afdcab17f8 100644 --- a/mindspore/ccsrc/utils/context/graph_kernel_flags.h +++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.h @@ -52,6 +52,9 @@ class GraphKernelFlags { // Check whether graph_kernel is enabled bool IsEnableGraphKernel() const { return opt_level > OptLevel_0; } + // Check whether GraphKernel supports current situation. + void CheckSupport() const; + GraphKernelFlags(const GraphKernelFlags &flags) = delete; GraphKernelFlags(GraphKernelFlags &&flags) = delete; void operator=(const GraphKernelFlags &flags) = delete;