Browse Source

Fix Conv2D op group attr problem.

tags/v1.1.0
liangchenghui 5 years ago
parent
commit
3ed33bb2aa
4 changed files with 6 additions and 3 deletions
  1. +1
    -1
      mindspore/nn/layer/pooling.py
  2. +2
    -1
      mindspore/ops/_op_impl/tbe/conv2d.py
  3. +1
    -1
      mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py
  4. +2
    -0
      mindspore/ops/operations/nn_ops.py

+ 1
- 1
mindspore/nn/layer/pooling.py View File

@@ -354,12 +354,12 @@ class AvgPool1d(_PoolNd):
      kernel_size=1,
      stride=1,
      pad_mode="valid"):
-     super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
      validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
      validator.check_value_type('stride', stride, [int], self.cls_name)
      self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
      validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
      validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
+     super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
      self.kernel_size = (1, kernel_size)
      self.stride = (1, stride)
      self.avg_pool = P.AvgPool(ksize=self.kernel_size,


+ 2
- 1
mindspore/ops/_op_impl/tbe/conv2d.py View File

@@ -27,7 +27,8 @@ conv2d_op_info = TBERegOp("Conv2D") \
      .attr("stride", "required", "listInt", "all") \
      .attr("pad_list", "required", "listInt", "all") \
      .attr("dilation", "required", "listInt", "all") \
-     .attr("offset_a", "optional", "int", "all") \
+     .attr("groups", "optional", "int", "all") \
+     .attr("data_format", "optional", "str", "all") \
      .input(0, "x", False, "required", "all") \
      .input(1, "filter", False, "required", "all") \
      .input(2, "bias", False, "optional", "all") \


+ 1
- 1
mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py View File

@@ -27,7 +27,7 @@ conv2d_backprop_input_op_info = TBERegOp("Conv2DBackpropInput") \
      .attr("stride", "required", "listInt", "all") \
      .attr("pad_list", "required", "listInt", "all") \
      .attr("dilation", "required", "listInt", "all") \
-     .attr("group", "optional", "int", "all") \
+     .attr("groups", "optional", "int", "all") \
      .attr("data_format", "optional", "str", "all") \
      .input(0, "out_backprop", False, "required", "all") \
      .input(1, "filter", False, "required", "all") \


+ 2
- 0
mindspore/ops/operations/nn_ops.py View File

@@ -1043,6 +1043,7 @@ class Conv2D(PrimitiveWithInfer):
      self.add_prim_attr('data_format', self.format)
      self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
      self.group = validator.check_positive_int(group, 'group', self.name)
+     self.add_prim_attr('groups', self.group)
      self.add_prim_attr('offset_a', 0)

  def infer_shape(self, x_shape, w_shape, b_shape=None):
@@ -1587,6 +1588,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
      self.add_prim_attr('pad_mode', pad_mode)
      self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
      self.group = validator.check_positive_int(group, 'group', self.name)
+     self.add_prim_attr('groups', self.group)
      self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
      if context.get_context("device_target") != "GPU" and self.format == "NHWC":
          raise ValueError("NHWC format only support in GPU target.")


Loading…
Cancel
Save