Browse Source

fix bug of Conv2DBackpropFilter

r1.7
buxue 4 years ago
parent
commit
c24e9c3e93
2 changed files with 18 additions and 19 deletions
  1. +4
    -8
      mindspore/ccsrc/plugin/device/gpu/kernel/nn/conv2d_grad_filter_gpu_kernel.h
  2. +14
    -11
      mindspore/python/mindspore/ops/operations/_grad_ops.py

+ 4
- 8
mindspore/ccsrc/plugin/device/gpu/kernel/nn/conv2d_grad_filter_gpu_kernel.h View File

@@ -38,9 +38,9 @@ constexpr size_t kBottom2DPadIndex = 1;
constexpr size_t kLeft2DPadIndex = 2;
constexpr size_t kRight2DPadIndex = 3;

constexpr size_t k2DStrideSize = 4;
constexpr size_t kHeight2DStrideIndex = 2;
constexpr size_t kWidth2DStrideIndex = 3;
constexpr size_t k2DStrideSize = 2;
constexpr size_t kHeight2DStrideIndex = 0;
constexpr size_t kWidth2DStrideIndex = 1;

constexpr size_t k2DDilationSize = 4;
constexpr size_t kHeight2DDilationIndex = 2;
@@ -353,13 +353,9 @@ class ConvGradFilterBkwGpuKernelMod : public NativeGpuKernelMod {
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
[](const int64_t &value) { return static_cast<int>(value); });
if (stride_.size() != k2DStrideSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'stride' should be 4, but got "
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'stride' should be 2, but got "
<< stride_.size();
}
if (stride_[0] != 1 || stride_[1] != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'stride' at 0 and 1 axis should be 1, but got "
<< "stride[0]: " << stride_[0] << ", stride[1]: " << stride_[1];
}
if (dilation_.size() != k2DDilationSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'dilation' should be 4, but got "
<< dilation_.size();


+ 14
- 11
mindspore/python/mindspore/ops/operations/_grad_ops.py View File

@@ -17,7 +17,7 @@
import math
from functools import partial
from mindspore._checkparam import _check_3d_int_or_tuple
from .nn_ops import _check_positive_int_or_tuple, _update_attr_by_format
from .nn_ops import _check_positive_int_or_tuple
from .. import signature as sig
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator, Rel
@@ -423,18 +423,21 @@ class Conv2DBackpropFilter(Primitive):
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
self.add_prim_attr("pad", pad)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
self.stride = _update_attr_by_format(self.stride, self.format)
self.add_prim_attr('stride', self.stride)
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.dilation = _update_attr_by_format(self.dilation, self.format)
self.add_prim_attr('dilation', self.dilation)
self.dilation = dilation
if context.get_context("device_target") != "GPU":
if self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True,
ret_four=True)
else:
if isinstance(stride, tuple) and len(stride) == 4:
self.stride = (stride[2], stride[3])
self.dilation = dilation
self.group = group
self.add_prim_attr('groups', group)
self.add_prim_attr('stride', self.stride)
self.add_prim_attr('dilation', self.dilation)
self.add_prim_attr('data_format', self.format)


class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):


Loading…
Cancel
Save