From a354d3226ae4a98ee4ce16701d8c09940926a267 Mon Sep 17 00:00:00 2001 From: baihuawei Date: Thu, 4 Feb 2021 20:37:05 +0800 Subject: [PATCH] cpu support tensor scalar --- .../kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc | 4 ++++ .../backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc index 402527bc65..76ff7989f4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc @@ -24,6 +24,10 @@ void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); std::vector src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + if (src1_shape.size() == 0 && src0_shape.size() == 0) { + src0_shape.insert(src0_shape.begin(), 1); + src1_shape.insert(src1_shape.begin(), 1); + } if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) { MS_LOG(EXCEPTION) << "AssignAdd only support same dim input or tensor * scalar " << src0_shape.size() << " vs " << src1_shape.size(); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc index 085d618ee2..163f0c1387 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc @@ -133,7 +133,11 @@ dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::d dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector &shape) { dnnl::memory::dims dims; - dims.insert(dims.end(), shape.begin(), shape.end()); + if (shape.size() == 0) { + dims.insert(dims.end(), 1); + } else { + dims.insert(dims.end(), shape.begin(), shape.end()); + } dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims); dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag); return mem_desc;