| @@ -354,12 +354,12 @@ class AvgPool1d(_PoolNd): | |||||
| kernel_size=1, | kernel_size=1, | ||||
| stride=1, | stride=1, | ||||
| pad_mode="valid"): | pad_mode="valid"): | ||||
| super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) | |||||
| validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name) | validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name) | ||||
| validator.check_value_type('stride', stride, [int], self.cls_name) | validator.check_value_type('stride', stride, [int], self.cls_name) | ||||
| self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name) | self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name) | ||||
| validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name) | validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name) | ||||
| validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name) | validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name) | ||||
| super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) | |||||
| self.kernel_size = (1, kernel_size) | self.kernel_size = (1, kernel_size) | ||||
| self.stride = (1, stride) | self.stride = (1, stride) | ||||
| self.avg_pool = P.AvgPool(ksize=self.kernel_size, | self.avg_pool = P.AvgPool(ksize=self.kernel_size, | ||||
| @@ -27,7 +27,8 @@ conv2d_op_info = TBERegOp("Conv2D") \ | |||||
| .attr("stride", "required", "listInt", "all") \ | .attr("stride", "required", "listInt", "all") \ | ||||
| .attr("pad_list", "required", "listInt", "all") \ | .attr("pad_list", "required", "listInt", "all") \ | ||||
| .attr("dilation", "required", "listInt", "all") \ | .attr("dilation", "required", "listInt", "all") \ | ||||
| .attr("offset_a", "optional", "int", "all") \ | |||||
| .attr("groups", "optional", "int", "all") \ | |||||
| .attr("data_format", "optional", "str", "all") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .input(1, "filter", False, "required", "all") \ | .input(1, "filter", False, "required", "all") \ | ||||
| .input(2, "bias", False, "optional", "all") \ | .input(2, "bias", False, "optional", "all") \ | ||||
| @@ -27,7 +27,7 @@ conv2d_backprop_input_op_info = TBERegOp("Conv2DBackpropInput") \ | |||||
| .attr("stride", "required", "listInt", "all") \ | .attr("stride", "required", "listInt", "all") \ | ||||
| .attr("pad_list", "required", "listInt", "all") \ | .attr("pad_list", "required", "listInt", "all") \ | ||||
| .attr("dilation", "required", "listInt", "all") \ | .attr("dilation", "required", "listInt", "all") \ | ||||
| .attr("group", "optional", "int", "all") \ | |||||
| .attr("groups", "optional", "int", "all") \ | |||||
| .attr("data_format", "optional", "str", "all") \ | .attr("data_format", "optional", "str", "all") \ | ||||
| .input(0, "out_backprop", False, "required", "all") \ | .input(0, "out_backprop", False, "required", "all") \ | ||||
| .input(1, "filter", False, "required", "all") \ | .input(1, "filter", False, "required", "all") \ | ||||
| @@ -1043,6 +1043,7 @@ class Conv2D(PrimitiveWithInfer): | |||||
| self.add_prim_attr('data_format', self.format) | self.add_prim_attr('data_format', self.format) | ||||
| self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) | self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) | ||||
| self.group = validator.check_positive_int(group, 'group', self.name) | self.group = validator.check_positive_int(group, 'group', self.name) | ||||
| self.add_prim_attr('groups', self.group) | |||||
| self.add_prim_attr('offset_a', 0) | self.add_prim_attr('offset_a', 0) | ||||
| def infer_shape(self, x_shape, w_shape, b_shape=None): | def infer_shape(self, x_shape, w_shape, b_shape=None): | ||||
| @@ -1587,6 +1588,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): | |||||
| self.add_prim_attr('pad_mode', pad_mode) | self.add_prim_attr('pad_mode', pad_mode) | ||||
| self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) | self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) | ||||
| self.group = validator.check_positive_int(group, 'group', self.name) | self.group = validator.check_positive_int(group, 'group', self.name) | ||||
| self.add_prim_attr('groups', self.group) | |||||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) | self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) | ||||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | if context.get_context("device_target") != "GPU" and self.format == "NHWC": | ||||
| raise ValueError("NHWC format only support in GPU target.") | raise ValueError("NHWC format only support in GPU target.") | ||||