Browse Source

!12145 cpu support tensor scalar

From: @huaweib
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @kisnwang
pull/12145/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
e9e5d7e152
2 changed files with 9 additions and 1 deletions
  1. +4
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc
  2. +5
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc

+ 4
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc View File

@@ -24,6 +24,10 @@ void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
if (src1_shape.size() == 0 && src0_shape.size() == 0) {
src0_shape.insert(src0_shape.begin(), 1);
src1_shape.insert(src1_shape.begin(), 1);
}
if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) { if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) {
MS_LOG(EXCEPTION) << "AssignAdd only support same dim input or tensor * scalar " << src0_shape.size() << " vs " MS_LOG(EXCEPTION) << "AssignAdd only support same dim input or tensor * scalar " << src0_shape.size() << " vs "
<< src1_shape.size(); << src1_shape.size();


+ 5
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc View File

@@ -133,7 +133,11 @@ dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::d


dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &shape) { dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &shape) {
dnnl::memory::dims dims; dnnl::memory::dims dims;
dims.insert(dims.end(), shape.begin(), shape.end());
if (shape.size() == 0) {
dims.insert(dims.end(), 1);
} else {
dims.insert(dims.end(), shape.begin(), shape.end());
}
dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims); dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims);
dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag); dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag);
return mem_desc; return mem_desc;


Loading…
Cancel
Save