@@ -18,11 +18,12 @@ from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore._checkparam import ParamValidator as validator, Rel
+from mindspore._checkparam import Validator
 from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
-__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d']
+__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d', 'Conv1d', 'Conv1dTranspose']
 class _Conv(Cell):
     """
@@ -241,6 +242,174 @@ class Conv2d(_Conv):
         return s
+class Conv1d(_Conv):
+    r"""
+    1D convolution layer.
+    Applies a 1D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, W_{in})`,
+    where :math:`N` is batch size and :math:`C_{in}` is channel number. For each batch of shape
+    :math:`(C_{in}, W_{in})`, the formula is defined as:
+    .. math::
+        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
+    where :math:`ccor` is the cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
+    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th
+    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
+    of the kernel and it has shape :math:`(\text{ks_w})`, where :math:`\text{ks_w}` is the width of the convolution kernel.
+    The full kernel has shape :math:`(C_{out}, C_{in} // \text{group}, \text{ks_w})`, where group is the group number
+    to split the input in the channel dimension.
+    If `pad_mode` is set to "valid", the output width will be
+    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
+    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor`.
+    The first introduction can be found in the paper `Gradient Based Learning Applied to Document Recognition
+    <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
+    Args:
+        in_channels (int): The number of input channels :math:`C_{in}`.
+        out_channels (int): The number of output channels :math:`C_{out}`.
+        kernel_size (int): Specifies the width of the 1D convolution window.
+        stride (int): The distance of kernel moving, an int number that represents
+            the width of movement. Default: 1.
+        pad_mode (str): Specifies the padding mode. The optional values are
+            "same", "valid", "pad". Default: "same".
+            - same: Adopts the way of completion. The output width will be the same as the input.
+              The total number of padding will be calculated in the horizontal direction
+              and evenly distributed to the left and right if possible. Otherwise, the
+              last extra padding will be done on the right side. If this mode is set, `padding`
+              must be 0.
+            - valid: Adopts the way of discarding. The possibly largest width of the output will be returned
+              without padding. Extra pixels will be discarded. If this mode is set, `padding`
+              must be 0.
+            - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
+              Tensor borders. `padding` should be greater than or equal to 0.
+        padding (int): Implicit paddings on both sides of the input. Default: 0.
+        dilation (int): Specifies the dilation rate
+            to use for dilated convolution. If set to :math:`k > 1`, there will
+            be :math:`k - 1` pixels skipped for each sampling location. Its value should
+            be greater than or equal to 1 and bounded by the width of the
+            input. Default: 1.
+        group (int): Splits filter into groups, `in_channels` and `out_channels` should be
+            divisible by the number of groups. Default: 1.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
+            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
+            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
+            as the constant 'One' and 'Zero' distributions are possible. Aliases 'xavier_uniform', 'he_uniform', 'ones'
+            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
+            Initializer for more details. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
+            Initializer and string are the same as 'weight_init'. Refer to the values of
+            Initializer for more details. Default: 'zeros'.
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`.
+    Outputs:
+        Tensor of shape :math:`(N, C_{out}, W_{out})`.
+    Examples:
+        >>> net = nn.Conv1d(120, 240, 4, has_bias=False, weight_init='normal')
+        >>> input = Tensor(np.ones([1, 120, 640]), mindspore.float32)
+        >>> net(input).shape
+        (1, 240, 640)
+    """
+    @cell_attr_register
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 pad_mode='same',
+                 padding=0,
+                 dilation=1,
+                 group=1,
+                 has_bias=False,
+                 weight_init='normal',
+                 bias_init='zeros'):
+        Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
+        Validator.check_value_type("stride", stride, [int], self.cls_name)
+        Validator.check_value_type("padding", padding, [int], self.cls_name)
+        Validator.check_value_type("dilation", dilation, [int], self.cls_name)
+        Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
+        Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
+        Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
+        Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
+        # The 1D case is expressed through the 2D operator by adding a dummy height dimension of size 1.
+        kernel_size = (1, kernel_size)
+        stride = (1, stride)
+        dilation = (1, dilation)
+        super(Conv1d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            pad_mode,
+            padding,
+            dilation,
+            group,
+            has_bias,
+            weight_init,
+            bias_init)
+        self.padding = (0, 0, padding, padding)
+        self.conv2d = P.Conv2D(out_channel=self.out_channels,
+                               kernel_size=self.kernel_size,
+                               mode=1,
+                               pad_mode=self.pad_mode,
+                               pad=self.padding,
+                               stride=self.stride,
+                               dilation=self.dilation,
+                               group=self.group)
+        self.bias_add = P.BiasAdd()
+        if pad_mode not in ('valid', 'same', 'pad'):
+            raise ValueError('Attr \'pad_mode\' of \'Conv1d\' Op passed '
+                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
+        self.expand_dims = P.ExpandDims()
+        self.squeeze = P.Squeeze(2)
+        self.shape = P.Shape()
+    def construct(self, x):
+        x_shape = self.shape(x)
+        if len(x_shape) == 3:
+            x = self.expand_dims(x, 2)
+        output = self.conv2d(x, self.weight)
+        if self.has_bias:
+            output = self.bias_add(output, self.bias)
+        if len(x_shape) == 3:
+            output = self.squeeze(output)
+        return output
+    def extend_repr(self):
+        s = 'input_channels={}, output_channels={}, kernel_size={},' \
+            'stride={}, pad_mode={}, padding={}, dilation={}, ' \
+            'group={}, has_bias={},' \
+            'weight_init={}, bias_init={}'.format(
+                self.in_channels,
+                self.out_channels,
+                self.kernel_size,
+                self.stride,
+                self.pad_mode,
+                self.padding,
+                self.dilation,
+                self.group,
+                self.has_bias,
+                self.weight,
+                self.bias)
+        if self.has_bias:
+            s += ', bias={}'.format(self.bias)
+        return s
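For intuition about the dummy-height trick used by Conv1d above, here is a minimal NumPy sketch (an illustration only, not MindSpore code; the helper name conv1d_via_2d is made up here). It inserts a height axis of size 1, correlates along the last axis, and squeezes the axis back out, which matches the 'valid' output width floor((W - ks_w) / stride) + 1:

import numpy as np

def conv1d_via_2d(x, weight, stride=1):
    """Correlate along the last axis by temporarily adding a height axis of size 1 ('valid' mode, no dilation)."""
    x4 = x[:, :, np.newaxis, :]                # (N, C_in, 1, W)
    w4 = weight[:, :, np.newaxis, :]           # (C_out, C_in, 1, ks_w)
    n, _, _, w_in = x4.shape
    c_out, _, _, ks_w = w4.shape
    w_out = (w_in - ks_w) // stride + 1
    out = np.zeros((n, c_out, 1, w_out), x.dtype)
    for j in range(w_out):
        patch = x4[:, :, :, j * stride:j * stride + ks_w]        # (N, C_in, 1, ks_w)
        out[:, :, 0, j] = np.einsum('nchw,ochw->no', patch, w4)  # sum over C_in, height, ks_w
    return out[:, :, 0, :]                     # squeeze the dummy height axis back out

x = np.ones((1, 120, 640), np.float32)
w = np.ones((240, 120, 4), np.float32)
print(conv1d_via_2d(x, w).shape)               # (1, 240, 637) -- 'valid' mode; pad_mode='same' keeps width 640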
 class Conv2dTranspose(_Conv):
     r"""
     2D transposed convolution layer.
@@ -400,6 +569,181 @@ class Conv2dTranspose(_Conv):
         return s
+class Conv1dTranspose(_Conv):
+    r"""
+    1D transposed convolution layer.
+    Computes a 1D transposed convolution, which is also known as a deconvolution
+    (although it is not an actual deconvolution).
+    Input is typically of shape :math:`(N, C, W)`, where :math:`N` is batch size and :math:`C` is channel number.
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        kernel_size (int): Specifies the width of the 1D convolution window.
+        stride (int): The distance of kernel moving, an int number that represents
+            the width of movement. Default: 1.
+        pad_mode (str): Specifies the padding mode. The optional values are
+            "pad", "same", "valid". Default: "same".
+            - pad: Implicit paddings on both sides of the input.
+            - same: Adopts the way of completion.
+            - valid: Adopts the way of discarding.
+        padding (int): Implicit paddings on both sides of the input. Default: 0.
+        dilation (int): Specifies the dilation rate
+            to use for dilated convolution. If set to :math:`k > 1`, there will
+            be :math:`k - 1` pixels skipped for each sampling location. Its value should
+            be greater than or equal to 1 and bounded by the width of the
+            input. Default: 1.
+        group (int): Splits filter into groups, `in_channels` and `out_channels` should be
+            divisible by the number of groups. This is not supported on Davinci devices when group > 1. Default: 1.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
+            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
+            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
+            as the constant 'One' and 'Zero' distributions are possible. Aliases 'xavier_uniform', 'he_uniform', 'ones'
+            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
+            Initializer for more details. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
+            Initializer and string are the same as 'weight_init'. Refer to the values of
+            Initializer for more details. Default: 'zeros'.
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`.
+    Outputs:
+        Tensor of shape :math:`(N, C_{out}, W_{out})`.
+    Examples:
+        >>> net = nn.Conv1dTranspose(3, 64, 4, has_bias=False, weight_init='normal')
+        >>> input = Tensor(np.ones([1, 3, 50]), mindspore.float32)
+        >>> net(input)
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 pad_mode='same',
+                 padding=0,
+                 dilation=1,
+                 group=1,
+                 has_bias=False,
+                 weight_init='normal',
+                 bias_init='zeros'):
+        Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
+        Validator.check_value_type("stride", stride, [int], self.cls_name)
+        Validator.check_value_type("padding", padding, [int], self.cls_name)
+        Validator.check_value_type("dilation", dilation, [int], self.cls_name)
+        Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
+        Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
+        Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
+        Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
+        kernel_size = (1, kernel_size)
+        stride = (1, stride)
+        dilation = (1, dilation)
+        # out_channels and in_channels are swapped,
+        # because Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
+        # so Conv1dTranspose's out_channel maps to Conv2DBackpropInput's in_channel.
+        super(Conv1dTranspose, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            pad_mode,
+            padding,
+            dilation,
+            group,
+            has_bias,
+            weight_init,
+            bias_init,
+            transposed=True)
+        self.padding = (0, 0, padding, padding)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.shape = P.Shape()
+        if pad_mode not in ('valid', 'same', 'pad'):
+            raise ValueError('Attr \'pad_mode\' of \'Conv1dTranspose\' Op passed '
+                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
+        self.is_valid = self.pad_mode == 'valid'
+        self.is_same = self.pad_mode == 'same'
+        self.is_pad = self.pad_mode == 'pad'
+        if check_bool(has_bias):
+            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
+        # Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.
+        self.conv2d_transpose = P.Conv2DBackpropInput(out_channel=in_channels,
+                                                      kernel_size=kernel_size,
+                                                      mode=1,
+                                                      pad_mode=pad_mode,
+                                                      pad=self.padding,
+                                                      stride=stride,
+                                                      dilation=dilation,
+                                                      group=group)
+        self.bias_add = P.BiasAdd()
+        self.expand_dims = P.ExpandDims()
+        self.squeeze = P.Squeeze(2)
+    def set_strategy(self, strategy):
+        self.conv2d_transpose.set_strategy(strategy)
+        return self
+    def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size, padding):
+        """Calculate the width and height of output."""
+        length = 0
+        filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
+        if self.is_valid:
+            if filter_size - stride_size > 0:
+                length = input_length * stride_size + filter_size - stride_size
+            else:
+                length = input_length * stride_size
+        elif self.is_same:
+            length = input_length * stride_size
+        elif self.is_pad:
+            length = input_length * stride_size - padding + filter_size - stride_size
+        return length
+    def construct(self, x):
+        x_shape = self.shape(x)
+        if len(x_shape) == 3:
+            x = self.expand_dims(x, 2)
+        n, _, h, w = self.shape(x)
+        h_out = self._deconv_output_length(h, self.kernel_size[0], self.stride[0], self.dilation[0],
+                                           self.padding[0] + self.padding[1])
+        w_out = self._deconv_output_length(w, self.kernel_size[1], self.stride[1], self.dilation[1],
+                                           self.padding[2] + self.padding[3])
+        output = self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out))
+        if self.has_bias:
+            output = self.bias_add(output, self.bias)
+        if len(x_shape) == 3:
+            output = self.squeeze(output)
+        return output
+    def extend_repr(self):
+        s = 'input_channels={}, output_channels={}, kernel_size={},' \
+            'stride={}, pad_mode={}, padding={}, dilation={}, ' \
+            'group={}, has_bias={},' \
+            'weight_init={}, bias_init={}'.format(self.in_channels,
+                                                  self.out_channels,
+                                                  self.kernel_size,
+                                                  self.stride,
+                                                  self.pad_mode,
+                                                  self.padding,
+                                                  self.dilation,
+                                                  self.group,
+                                                  self.has_bias,
+                                                  self.weight,
+                                                  self.bias)
+        return s
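The output-length rule in `_deconv_output_length` can be sanity-checked in isolation. A small plain-Python sketch of the same three branches (function and variable names chosen here for illustration):

def deconv_output_length(in_len, kernel, stride, dilation, pad_total, mode):
    kernel = kernel + (kernel - 1) * (dilation - 1)   # effective (dilated) kernel width
    if mode == 'valid':
        return in_len * stride + max(kernel - stride, 0)
    if mode == 'same':
        return in_len * stride
    return in_len * stride - pad_total + kernel - stride   # 'pad': pad_total = left + right padding

print(deconv_output_length(50, 4, 1, 1, 0, 'valid'))   # 53
print(deconv_output_length(50, 4, 1, 1, 0, 'same'))    # 50, matching the docstring example above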
 class DepthwiseConv2d(Cell):
     r"""
     2D depthwise convolution layer.
@@ -780,7 +780,9 @@ class Conv2D(PrimitiveWithInfer):
         mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
             2 deconvolution, 3 depthwise convolution. Default: 1.
         pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
-        pad (int): The pad value to fill. Default: 0.
+        pad (Union(int, tuple[int])): The pad value to fill. Default: 0. If `pad` is an integer, the paddings of
+            top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
+            paddings of top, bottom, left and right equal pad[0], pad[1], pad[2] and pad[3] respectively.
         stride (Union(int, tuple[int])): The stride to apply conv filter. Default: 1.
         dilation (Union(int, tuple[int])): Specify the space to use between kernel elements. Default: 1.
         group (int): Split input into groups. Default: 1.
@@ -820,11 +822,19 @@ class Conv2D(PrimitiveWithInfer):
         self.add_prim_attr('stride', self.stride)
         self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
         self.add_prim_attr('dilation', self.dilation)
-        validator.check_value_type('pad', pad, (int,), self.name)
+        validator.check_value_type('pad', pad, (int, tuple), self.name)
+        if isinstance(pad, int):
+            pad = (pad,) * 4
+        else:
+            validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
+        self.padding = pad
         self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
-        self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
+        if pad_mode != 'pad' and pad != (0, 0, 0, 0):
+            raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
         if self.pad_mode == 'pad':
-            validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
+            for item in pad:
+                validator.check_integer('pad item', item, 0, Rel.GE, self.name)
         self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
         self.add_prim_attr('data_format', "NCHW")
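A standalone sketch of the pad normalization and checks that the new Conv2D `__init__` code performs (hypothetical helper name, shown only to make the accepted inputs concrete):

def normalize_pad(pad, pad_mode):
    """Mirror of the checks above: accept an int or a 4-tuple, reject non-zero padding unless pad_mode == 'pad'."""
    if isinstance(pad, int):
        pad = (pad,) * 4
    elif not (isinstance(pad, tuple) and len(pad) == 4):
        raise ValueError(f"'pad' should be an int or a tuple of four ints, but got {pad}")
    if pad_mode != 'pad' and pad != (0, 0, 0, 0):
        raise ValueError(f"padding must be zero when pad_mode is '{pad_mode}'")
    if pad_mode == 'pad' and any(p < 0 for p in pad):
        raise ValueError("every pad item should be >= 0")
    return pad

print(normalize_pad(2, 'pad'))             # (2, 2, 2, 2)
print(normalize_pad((1, 2, 3, 4), 'pad'))  # (1, 2, 3, 4)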
@@ -862,11 +872,11 @@ class Conv2D(PrimitiveWithInfer):
             pad_left = math.floor(pad_needed_w / 2)
             pad_right = pad_needed_w - pad_left
         elif self.pad_mode == 'pad':
-            pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
+            pad_top, pad_bottom, pad_left, pad_right = self.padding
-            h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
+            h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
                 / stride_h
-            w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
+            w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
                 / stride_w
             h_out = math.floor(h_out)
             w_out = math.floor(w_out)
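With four independent pad values, each spatial side now contributes separately to the output size. A worked check of the 'pad' branch formula (plain arithmetic, example numbers chosen here):

import math

x_h, x_w = 32, 32
kernel_size_h = kernel_size_w = 3
stride_h = stride_w = 2
dilation_h = dilation_w = 1
pad_top, pad_bottom, pad_left, pad_right = 0, 1, 0, 1   # asymmetric padding

h_out = math.floor(1 + (x_h + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h)
w_out = math.floor(1 + (x_w + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w)
print(h_out, w_out)   # 16 16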
@@ -1277,7 +1287,9 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
         out_channel (int): The dimensionality of the output space.
         kernel_size (Union[int, tuple[int]]): The size of the convolution window.
         pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
-        pad (int): The pad value to fill. Default: 0.
+        pad (Union[int, tuple[int]]): The pad value to fill. Default: 0. If `pad` is an integer, the paddings of
+            top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
+            paddings of top, bottom, left and right equal pad[0], pad[1], pad[2] and pad[3] respectively.
         mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
             2 deconvolution, 3 depthwise convolution. Default: 1.
         stride (Union[int. tuple[int]]): The stride to apply conv filter. Default: 1.
@@ -1314,9 +1326,21 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
         self.add_prim_attr('stride', self.stride)
         self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
         self.add_prim_attr('dilation', self.dilation)
-        validator.check_value_type('pad', pad, (int,), self.name)
+        validator.check_value_type('pad', pad, (int, tuple), self.name)
+        if isinstance(pad, int):
+            pad = (pad,) * 4
+        else:
+            validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
+        self.pad = pad
         self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
-        self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
+        if pad_mode != 'pad' and pad != (0, 0, 0, 0):
+            raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
+        if self.pad_mode == 'pad':
+            for item in pad:
+                validator.check_integer('pad item', item, 0, Rel.GE, self.name)
         pad_mode = pad_mode.upper()
         self.add_prim_attr('pad_mode', pad_mode)
         self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
@@ -1358,7 +1382,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
             pad_right = pad_needed_w - pad_left
             pad_list = (pad_top, pad_bottom, pad_left, pad_right)
         elif self.pad_mode == 'PAD':
-            pad_list = (self.pad,) * 4
+            pad_list = self.pad
         self.add_prim_attr('pad_list', pad_list)
         out = {
             'value': None,
@@ -22,11 +22,22 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
 def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
     """Rearranges an image to row vector"""
+    if isinstance(pad, int):
+        pad_top = pad
+        pad_bottom = pad
+        pad_left = pad
+        pad_right = pad
+    elif isinstance(pad, tuple) and len(pad) == 4:
+        pad_top, pad_bottom, pad_left, pad_right = pad
+    else:
+        raise ValueError(f"The \'pad\' should be an int number or "
+                         f"a tuple of four int numbers, but got {pad}")
     batch_num, channel, height, width = img.shape
-    out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1
-    out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1
+    out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1
+    out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1
-    img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
+    img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
     col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
     for y in range(filter_h):
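The asymmetric `np.pad` call pads each spatial side independently; a quick NumPy shape check with made-up pad values:

import numpy as np

img = np.zeros((1, 3, 5, 5))
pad_top, pad_bottom, pad_left, pad_right = 1, 2, 0, 3
padded = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
print(padded.shape)   # (1, 3, 8, 8): height 5 + 1 + 2, width 5 + 0 + 3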
@@ -43,10 +54,21 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
 def conv2d(x, weight, bias=None, stride=1, pad=0,
            dilation=1, groups=1, padding_mode='zeros'):
     """Convolution 2D"""
+    if isinstance(pad, int):
+        pad_top = pad
+        pad_bottom = pad
+        pad_left = pad
+        pad_right = pad
+    elif isinstance(pad, tuple) and len(pad) == 4:
+        pad_top, pad_bottom, pad_left, pad_right = pad
+    else:
+        raise ValueError(f"The \'pad\' should be an int number or "
+                         f"a tuple of four int numbers, but got {pad}")
     batch_num, _, x_h, x_w = x.shape
     filter_num, _, filter_h, filter_w = weight.shape
-    out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2])
-    out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3])
+    out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2])
+    out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3])
     col = im2col(x, filter_h, filter_w, stride, pad, dilation)
     col_w = np.reshape(weight, (filter_num, -1)).T
     out = np.dot(col, col_w)
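The `im2col` followed by `np.dot` pattern above turns the convolution into a single matrix multiply. A compact NumPy illustration of that idea (simplified here to stride 1, no padding or dilation):

import numpy as np

x = np.arange(16, dtype=np.float64).reshape(1, 1, 4, 4)   # (N, C, H, W)
w = np.ones((2, 1, 3, 3))                                  # (C_out, C, kH, kW)
out_h = out_w = 4 - 3 + 1

# gather every 3x3 patch into a row, then one matrix multiply against the flattened filters
patches = np.stack([x[0, :, i:i + 3, j:j + 3].ravel()
                    for i in range(out_h) for j in range(out_w)])   # (out_h * out_w, C * 9)
out = patches @ np.reshape(w, (2, -1)).T                            # (out_h * out_w, C_out)
out = out.T.reshape(1, 2, out_h, out_w)
print(out[0, 0])   # each entry is the sum of one 3x3 window of x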
@@ -169,16 +169,32 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
         raise ValueError(f"The \'stride\' should be an int number or "
                          f"a tuple of two or four int numbers, but got {stride}")
+    if isinstance(pad, int):
+        pad_top = pad
+        pad_bottom = pad
+        pad_left = pad
+        pad_right = pad
+    elif isinstance(pad, tuple) and len(pad) == 2:
+        pad_top = pad[0]
+        pad_bottom = pad[0]
+        pad_left = pad[1]
+        pad_right = pad[1]
+    elif isinstance(pad, tuple) and len(pad) == 4:
+        pad_top, pad_bottom, pad_left, pad_right = pad
+    else:
+        raise ValueError(f"The \'pad\' should be an int number or "
+                         f"a tuple of two or four int numbers, but got {pad}")
     batch_num, channel, height, width = input_shape
-    out_h = (height + 2 * pad - filter_h) // stride_h + 1
-    out_w = (width + 2 * pad - filter_w) // stride_w + 1
+    out_h = (height + pad_top + pad_bottom - filter_h) // stride_h + 1
+    out_w = (width + pad_left + pad_right - filter_w) // stride_w + 1
     col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \
         .transpose(0, 3, 4, 5, 1, 2)
     img = np.zeros((batch_num,
                     channel,
-                    height + 2 * pad + stride_h - 1,
-                    width + 2 * pad + stride_w - 1)) \
+                    height + pad_top + pad_bottom + stride_h - 1,
+                    width + pad_left + pad_right + stride_w - 1)) \
         .astype(col.dtype)
     for y in range(filter_h):
         y_max = y + stride_h * out_h
@@ -186,7 +202,7 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
             x_max = x + stride_h * out_w
             img[:, :, y:y_max:stride_h, x:x_max:stride_h] += col[:, :, y, x, :, :]
-    return img[:, :, pad:height + pad, pad:width + pad]
+    # crop back to the original (unpadded) spatial size
+    return img[:, :, pad_top:height + pad_top, pad_left:width + pad_left]
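Assuming the `pad_top:height + pad_top` / `pad_left:width + pad_left` form of the crop above, the returned array has exactly the original spatial size even for asymmetric padding; a quick NumPy check with arbitrary pad values:

import numpy as np

height, width = 5, 7
pad_top, pad_bottom, pad_left, pad_right = 1, 2, 0, 3
img = np.zeros((1, 1, height + pad_top + pad_bottom, width + pad_left + pad_right))
print(img[:, :, pad_top:height + pad_top, pad_left:width + pad_left].shape)   # (1, 1, 5, 7)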
 def convolve(x, w, b=None, pad_mode="valid"):
@@ -243,10 +259,21 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
     dilation_h = dilation[0]
     dilation_w = dilation[1]
+    if isinstance(pad, int):
+        pad_top = pad
+        pad_bottom = pad
+        pad_left = pad
+        pad_right = pad
+    elif isinstance(pad, tuple) and len(pad) == 4:
+        pad_top, pad_bottom, pad_left, pad_right = pad
+    else:
+        raise ValueError(f"The \'pad\' should be an int number or "
+                         f"a tuple of four int numbers, but got {pad}")
     batch_num, _, x_h, x_w = x.shape
     filter_num, _, filter_h, filter_w = weight.shape
-    out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
-    out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
+    out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
+    out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
     col = im2col(x, filter_h, filter_w, stride, pad, dilation)
     col_w = np.reshape(weight, (filter_num, -1)).T
     out = np.dot(col, col_w)
@@ -348,11 +375,22 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
         raise ValueError(f"The \'dilation\' should be an int number or "
                          f"a tuple of two or four int numbers, but got {dilation}")
+    if isinstance(pad, int):
+        pad_top = pad
+        pad_bottom = pad
+        pad_left = pad
+        pad_right = pad
+    elif isinstance(pad, tuple) and len(pad) == 4:
+        pad_top, pad_bottom, pad_left, pad_right = pad
+    else:
+        raise ValueError(f"The \'pad\' should be an int number or "
+                         f"a tuple of four int numbers, but got {pad}")
     batch_num, channel, height, width = img.shape
-    out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
-    out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1
+    out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
+    out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1
-    img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
+    img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
     col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
     for y in range(filter_h):