From 1ca333a3e53b9fa5653e7eae7d43b2ff3ba58b45 Mon Sep 17 00:00:00 2001 From: baihuawei Date: Thu, 5 Nov 2020 21:03:09 +0800 Subject: [PATCH] fix cpu kernel select --- .../cpu/arithmetic_cpu_kernel.cc | 6 +-- .../cpu/arithmetic_self_cpu_kernel.cc | 2 +- .../backend/kernel_compiler/cpu/cpu_kernel.cc | 5 ++- .../cpu/eltwise_grad_cpu_kernel.cc | 6 +-- .../runtime/device/cpu/cpu_kernel_runtime.cc | 5 ++- .../runtime/device/cpu/kernel_select_cpu.cc | 38 +++++++++++++++++++ 6 files changed, 53 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc index 011cf16abd..2e7e0ec5f1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc @@ -132,14 +132,14 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { bool ArithmeticCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, const std::vector &outputs) { - if (dtype_ == kNumberTypeInt32) { + if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) { LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat32) { + } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) { LaunchKernel(inputs, outputs); } else if (dtype_ == kNumberTypeInt64) { LaunchKernel(inputs, outputs); } else { - MS_LOG(EXCEPTION) << "Only support int32, float32, but actual data type is " << TypeIdLabel(dtype_); + MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; } return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc index fdfb48481f..c5d71bc711 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc @@ -56,7 +56,7 @@ bool ArithmeticSelfCPUKernel::Launch(const std::vector &inpu } else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) { LaunchKernel(inputs, outputs); } else { - MS_LOG(EXCEPTION) << "Only support float32, int32, but actual data type is " << TypeIdLabel(dtype_); + MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; } return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc index fb9398e7c4..0530e048bd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc @@ -20,8 +20,9 @@ namespace kernel { void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - size_t type_size = sizeof(float); for (size_t input_index = 0; input_index < input_num; ++input_index) { + TypeId type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, input_index); + size_t type_size = GetTypeByte(TypeIdToType(type_id)); std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index); size_t tensor_size = shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); @@ -29,6 +30,8 @@ void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); for (size_t output_index = 0; output_index < output_num; ++output_index) { + TypeId type_id = AnfAlgo::GetOutputDeviceDataType(kernel_node, output_index); + size_t type_size = GetTypeByte(TypeIdToType(type_id)); std::vector shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index); size_t tensor_size = shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc index f47c5c2a59..e7a91461ea 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc @@ -123,14 +123,14 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { bool EltWiseGradCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, const std::vector &outputs) { - if (dtype_ == kNumberTypeInt32) { + if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) { LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat32) { + } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) { LaunchKernel(inputs, outputs); } else if (dtype_ == kNumberTypeInt64) { LaunchKernel(inputs, outputs); } else { - MS_LOG(EXCEPTION) << "Only support int32, float32, but actual data type is " << TypeIdLabel(dtype_); + MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; } return true; } diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index bccd2f4af3..b800a1b3c8 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -54,7 +54,10 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph } auto tensor = node_value->cast(); MS_EXCEPTION_IF_NULL(tensor); - TypeId output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0); + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0); + } size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); ShapeVector data_shape = tensor->shape(); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); diff --git a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc index bdb2abc19b..89257b304b 100644 --- a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc +++ b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc @@ -205,6 +205,37 @@ void SetKernelBuildInfo(const std::vector &input_formats, const std builder->SetOutputsDeviceType(output_types); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node); } + +void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector &input_types, + const std::vector &infer_output_types) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + std::stringstream operator_info; + operator_info << "Operator[" << kernel_name << "] "; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num > 0) { + operator_info << " input("; + for (size_t i = 0; i < input_num; ++i) { + operator_info << TypeIdLabel(input_types[i]); + if (i != input_num - 1) { + operator_info << ","; + } + } + operator_info << ") "; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num > 0) { + operator_info << "output("; + for (size_t i = 0; i < output_num; ++i) { + operator_info << TypeIdLabel(infer_output_types[i]); + if (i != output_num - 1) { + operator_info << ","; + } + } + operator_info << ") "; + } + operator_info << "is not support."; + MS_LOG(EXCEPTION) << operator_info.str(); +} } // namespace bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr, const std::vector &kernel_attrs, const std::vector &input_formats, @@ -275,10 +306,17 @@ void SetKernelInfo(const CNodePtr &kernel_node) { std::pair matched = std::make_pair(false, false); if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) { + if (AnfAlgo::GetCNodeName(kernel_node) == "Cast") { + KernelNotSupportException(kernel_node, input_types, infer_output_types); + } matched = std::make_pair(false, false); SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, false); + if (!matched.first) { + KernelNotSupportException(kernel_node, input_types, infer_output_types); + } } + if (selected_kernel_attr.GetInputSize() > 0 && (matched.first || input_types.size() == input_not_cnode_indexes.size())) { MS_LOG(INFO) << "Input format and dtype is matched";