Browse Source

fix conv2d 2

tags/v0.3.0-alpha
zhaozhenlong 6 years ago
parent
commit
366b6d6803
1 changed files with 27 additions and 80 deletions
  1. +27
    -80
      mindspore/ops/operations/_grad_ops.py

+ 27
- 80
mindspore/ops/operations/_grad_ops.py View File

@@ -137,9 +137,9 @@ class ConcatOffset(PrimitiveWithInfer):
return out


class Conv2DBackpropInput(PrimitiveWithInfer):
class Conv2DBackpropFilter(PrimitiveWithInfer):
"""
Computes the gradients of convolution with respect to the input.
Computes the gradients of convolution with respect to the filter.

Args:
out_channel (int): The dimensionality of the output space.
@@ -147,9 +147,9 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
pad (int): The pad value to fill. Default: 0.
mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
2 deconvolution, 3 depthwise convolution. Default: 1.
stride (Union[int. tuple[int]]): The stride to apply conv filter. Default: 1.
dilation (Union[int. tuple[int]]): Specifies the dilation rate to use for dilated convolution. Default: 1.
2 deconvolution, 3 depthwise convolution. Default: 1.
stride (tuple): The stride to apply conv filter. Default: (1, 1).
dilation (tuple): Specifies the dilation rate to use for dilated convolution. Default: (1, 1, 1, 1).
group (int): Splits input into groups. Default: 1.

Returns:
@@ -162,89 +162,36 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
kernel_size,
pad_mode="valid",
pad=0,
pad_list=None,
pad_list=(0, 0, 0, 0),
mode=1,
stride=1,
dilation=1,
stride=(1, 1),
dilation=(1, 1, 1, 1),
group=1):
"""init Conv2DBackpropInput"""
self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT)
self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple))
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \
(not isinstance(self.kernel_size[1], int)) or \
self.kernel_size[0] < 1 or self.kernel_size[1] < 1:
raise ValueError(f"The \'kernel_size\' of \'Conv2DBackpropInput\' should be an positive int number or "
f"a tuple of two positive int numbers, but got {kernel_size}")
self.stride = validator.check_type('stride', stride, (int, tuple))
if isinstance(stride, int):
self.stride = (stride, stride)
elif isinstance(stride, tuple) and len(stride) == 4:
self.stride = (stride[2], stride[3])
if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or (not isinstance(self.stride[1], int)) or \
self.stride[0] < 1 or self.stride[1] < 1:
raise ValueError(f"The \'stride\' of \'Conv2DBackpropInput\' should be an positive int number or "
f"a tuple of two or four positive int numbers, but got {stride}")
self.add_prim_attr('stride', self.stride)
self.dilation = validator.check_type('dilation', dilation, (tuple, int))
if isinstance(dilation, int):
self.dilation = (1, 1, dilation, dilation)
elif len(dilation) == 2:
self.dilation = (1, 1, dilation[0], dilation[1])
if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \
(not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \
(not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \
(not isinstance(self.dilation[3], int) or self.dilation[3] < 1):
raise ValueError(f"The \'dilation\' of \'Conv2DBackpropInput\' should be an positive int number or "
f"a tuple of two or four positive int numbers, but got {dilation}")
self.add_prim_attr('dilation', self.dilation)
validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool))
validator.equal('type of pad', type(pad), 'int', isinstance(pad, int))
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'])
self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad)
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ)
self.group = validator.check_integer('group', group, 0, Rel.GT)
"""init Convolution"""
self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
self.out_channel = out_channel
self.kernel_size = kernel_size
self.mode = mode
pad_mode = pad_mode.upper()
self.add_prim_attr('pad_mode', pad_mode)
self.pad = pad
if isinstance(stride, tuple) and len(stride) == 4:
self.stride = (stride[2], stride[3])
self.add_prim_attr('stride', self.stride)
self.dilation = dilation
self.group = group
self.add_prim_attr('data_format', "NCHW")
if pad_list:
self.pad_lsit = (validator.check_integer('pad_list', x, 0, Rel.GE) for x in pad_list)

def __infer__(self, doutput, w, x_size):
x_size_v = x_size['value']
validator.check_type('x_size', x_size_v, [tuple])
for i, dim_len in enumerate(x_size_v):
validator.check_type("x_size[%d]" % i, dim_len, [int])
validator.check_typename('w_dtype', w['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32])
validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'w_dtype', w['dtype'])

# infer shape
dout_shape = doutput['shape']
kernel_h = self.kernel_size[0]
kernel_w = self.kernel_size[1]
stride_h = self.stride[0]
stride_w = self.stride[1]
# default pad mode is valid
pad_list = (0, 0, 0, 0)
if self.pad_list:
pad_list = tuple(self.pad_list)
elif self.pad_mode == "SAME":
pad_needed_h = max(0, (dout_shape[2] - 1) * stride_h + kernel_h - x_size_v[2])
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top

pad_needed_w = max(0, (dout_shape[3] - 1) * stride_w + kernel_w - x_size_v[3])
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
pad_list = (pad_top, pad_bottom, pad_left, pad_right)
elif self.pad_mode == 'PAD':
pad_list = (self.pad,) * 4
self.add_prim_attr('pad_list', pad_list)
def __infer__(self, doutput, x, w_size):
w_size_v = w_size['value']
validator.check_value_type('w_size', w_size_v, [tuple], self.name)
for i, dim_len in enumerate(w_size_v):
validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
args = {"x": x['dtype'], "doutput": doutput['dtype']}
validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name)
out = {
'value': None,
'shape': x_size_v,
'shape': w_size_v,
'dtype': doutput['dtype'],
}
return out


Loading…
Cancel
Save