From 07b50f76ab8d6593eb87d5903daaf9f6b6961afc Mon Sep 17 00:00:00 2001 From: linqingke Date: Wed, 10 Mar 2021 16:14:12 +0800 Subject: [PATCH] fix conv2d_grad_inpu infer strides. --- .../ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h | 1 + .../cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc | 13 ++++++++++++- mindspore/ops/operations/nn_ops.py | 4 ++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index 8bd3f04e26..3b337781d7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -34,6 +34,7 @@ const char KERNEL_SIZE[] = "kernel_size"; const char STRIDE[] = "stride"; const char STRIDES[] = "strides"; const char DILATION[] = "dilation"; +const char FORMAT[] = "format"; const char PAD[] = "pad"; const char PAD_LIST[] = "pad_list"; const char PAD_MODE[] = "pad_mode"; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc index 8a61dce8be..fa09853a6d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc @@ -15,6 +15,7 @@ */ #include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h" #include +#include #include #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" @@ -22,6 +23,7 @@ namespace mindspore { namespace kernel { +const std::map kFormatIndexMap = {{"NCHW", 2}, {"HWCN", 0}, {"NHWC", 1}}; void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); std::vector src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); @@ -47,7 +49,16 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { std::vector dilation_ori; auto stride_me = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); auto dilation_me = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); - (void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_ori), + auto format_me = AnfAlgo::GetNodeAttr(kernel_node, FORMAT); + auto iter = kFormatIndexMap.find(format_me); + if (iter == kFormatIndexMap.end()) { + MS_LOG(EXCEPTION) << "OriFormat is " << format_me << ", Please confirm that in {NCHW, HWCN, NHWC}."; + } + size_t h_index = iter->second; + if (stride_me.size() < h_index + 2) { + MS_LOG(EXCEPTION) << "Strides should greater than " << h_index + 1 << ", but got " << stride_me.size(); + } + (void)std::transform(stride_me.begin() + h_index, stride_me.begin() + h_index + 2, std::back_inserter(stride_ori), [](const int64_t &value) { return static_cast(value); }); (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori), [](const int64_t &value) { return static_cast(value); }); diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index c684894310..a1cc71e5a6 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2028,8 +2028,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer): [dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]] kernel_h = self.kernel_size[0] kernel_w = self.kernel_size[1] - stride_h = self.stride[0] - stride_w = self.stride[1] + stride_h = self.stride[2] + stride_w = self.stride[3] dilation_h = self.dilation[2] dilation_w = self.dilation[3] # default pad mode is valid