Browse Source

fix conv2d_grad_inpu infer strides.

tags/v1.2.0-rc1
linqingke 4 years ago
parent
commit
07b50f76ab
3 changed files with 15 additions and 3 deletions
  1. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
  2. +12
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc
  3. +2
    -2
      mindspore/ops/operations/nn_ops.py

+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h View File

@@ -34,6 +34,7 @@ const char KERNEL_SIZE[] = "kernel_size";
const char STRIDE[] = "stride"; const char STRIDE[] = "stride";
const char STRIDES[] = "strides"; const char STRIDES[] = "strides";
const char DILATION[] = "dilation"; const char DILATION[] = "dilation";
const char FORMAT[] = "format";
const char PAD[] = "pad"; const char PAD[] = "pad";
const char PAD_LIST[] = "pad_list"; const char PAD_LIST[] = "pad_list";
const char PAD_MODE[] = "pad_mode"; const char PAD_MODE[] = "pad_mode";


+ 12
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc View File

@@ -15,6 +15,7 @@
*/ */
#include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h"
#include <string> #include <string>
#include <map>
#include <algorithm> #include <algorithm>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
@@ -22,6 +23,7 @@


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
const std::map<std::string, size_t> kFormatIndexMap = {{"NCHW", 2}, {"HWCN", 0}, {"NHWC", 1}};
void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); std::vector<size_t> src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
@@ -47,7 +49,16 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<int> dilation_ori; std::vector<int> dilation_ori;
auto stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDE); auto stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDE);
auto dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, DILATION); auto dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, DILATION);
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_ori),
auto format_me = AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
auto iter = kFormatIndexMap.find(format_me);
if (iter == kFormatIndexMap.end()) {
MS_LOG(EXCEPTION) << "OriFormat is " << format_me << ", Please confirm that in {NCHW, HWCN, NHWC}.";
}
size_t h_index = iter->second;
if (stride_me.size() < h_index + 2) {
MS_LOG(EXCEPTION) << "Strides should greater than " << h_index + 1 << ", but got " << stride_me.size();
}
(void)std::transform(stride_me.begin() + h_index, stride_me.begin() + h_index + 2, std::back_inserter(stride_ori),
[](const int64_t &value) { return static_cast<int>(value); }); [](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori), (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori),
[](const int64_t &value) { return static_cast<int>(value); }); [](const int64_t &value) { return static_cast<int>(value); });


+ 2
- 2
mindspore/ops/operations/nn_ops.py View File

@@ -2028,8 +2028,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
[dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]] [dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]]
kernel_h = self.kernel_size[0] kernel_h = self.kernel_size[0]
kernel_w = self.kernel_size[1] kernel_w = self.kernel_size[1]
stride_h = self.stride[0]
stride_w = self.stride[1]
stride_h = self.stride[2]
stride_w = self.stride[3]
dilation_h = self.dilation[2] dilation_h = self.dilation[2]
dilation_w = self.dilation[3] dilation_w = self.dilation[3]
# default pad mode is valid # default pad mode is valid


Loading…
Cancel
Save