From df7281c367a90d0839f5240289dcfcc75c1324a6 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Wed, 20 May 2020 19:11:13 +0800 Subject: [PATCH] cpu kernel support mutil dtype --- .../ccsrc/device/cpu/kernel_select_cpu.cc | 24 ++++++++++++------- .../ccsrc/kernel/cpu/cpu_kernel_factory.h | 8 ++++--- .../ccsrc/kernel/cpu/reshape_cpu_kernel.h | 8 +++++++ 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc index f7ccc443aa..6972a58125 100644 --- a/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc +++ b/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc @@ -59,6 +59,7 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vectoremplace_back(input_index); + dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); } else { dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); } @@ -84,22 +85,25 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector< const std::vector &input_types, const std::vector &input_not_cnode_indexes) { if (kernel_attr.GetInputSize() != input_types.size()) { - MS_LOG(ERROR) << "Output num is not equal!"; + MS_LOG(ERROR) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size(); return false; } auto input_num = input_types.size(); for (size_t i = 0; i < input_num; ++i) { bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), [i](size_t index) { return index == i; }); - if (is_not_cnode_idx) { + bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); + if (have_cnode_input && is_not_cnode_idx) { continue; } if (kernel_attr.GetInputAttr(i).first != input_types[i]) { - MS_LOG(ERROR) << "reg dtype=" << kernel_attr.GetInputAttr(i).first << ", input dtype=" << input_types[i]; + MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first + << ", actual input dtype:" << input_types[i]; return false; } if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { - MS_LOG(ERROR) << "reg format=" << kernel_attr.GetInputAttr(i).second << ", input format=" << input_formats[i]; + MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second + << ", actual input format:" << input_formats[i]; return false; } } @@ -114,17 +118,19 @@ void SetKernelInfo(const CNodePtr &kernel_node) { std::vector output_formats; std::vector output_types; + MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node); GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes); auto kernel_attrs = kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); - for (auto &kernel_attr : kernel_attrs) { - if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { - GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); - UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); + for (size_t index = 0; index < kernel_attrs.size(); ++index) { + if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) { + MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; + GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types); + UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node); for (auto &input_index : input_not_cnode_indexes) { - input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; + input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first; } break; } diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h index 17ea04070c..4a10c0ba5f 100644 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h +++ b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h @@ -55,10 +55,12 @@ class CPUKernelRegistrar { ~CPUKernelRegistrar() = default; }; -#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) \ +#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS) +#define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) +#define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \ static_assert(std::is_base_of::value, " must be base of CPUKernel"); \ - static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_reg(#OPNAME, ATTR, \ - []() { return std::make_shared(); }); + static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ + []() { return std::make_shared(); }); #define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \ static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ diff --git a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h index 1d5a8fa203..6ca746f4ac 100644 --- a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h @@ -35,10 +35,18 @@ class ReshapeCPUKernel : public CPUKernel { MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); + MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); + MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReshapeCPUKernel); +MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); } // namespace kernel } // namespace mindspore