From: @jiangzg001 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghuitags/v1.2.0-rc1
| @@ -420,6 +420,100 @@ class BatchNorm2d(_BatchNorm): | |||
| pass | |||
| class BatchNorm3d(Cell): | |||
| r""" | |||
| Batch normalization layer over a 5D input. | |||
| Batch Normalization is widely used in convolutional networks. This layer | |||
| applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with | |||
| additional channel dimension) to avoid internal covariate shift. | |||
| .. math:: | |||
| y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta | |||
| Note: | |||
| The implementation of BatchNorm is different in graph mode and pynative mode, therefore that mode can not be | |||
| changed after net was initilized. | |||
| Note that the formula for updating the running_mean and running_var is | |||
| :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`, | |||
| where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. | |||
| Args: | |||
| num_features (int): `C` from an expected input of size (N, C, D, H, W). | |||
| eps (float): A value added to the denominator for numerical stability. Default: 1e-5. | |||
| momentum (float): A floating hyperparameter of the momentum for the | |||
| running_mean and running_var computation. Default: 0.9. | |||
| affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. | |||
| gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. | |||
| The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', | |||
| 'he_uniform', etc. Default: 'ones'. | |||
| beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. | |||
| The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', | |||
| 'he_uniform', etc. Default: 'zeros'. | |||
| moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. | |||
| The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', | |||
| 'he_uniform', etc. Default: 'zeros'. | |||
| moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. | |||
| The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', | |||
| 'he_uniform', etc. Default: 'ones'. | |||
| use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, | |||
| use the mean value and variance value of specified value. If None, the training process will use the mean | |||
| and variance of current batch data and track the running mean and variance, the evaluation process will use | |||
| the running mean and variance. Default: None. | |||
| data_format (str): The optional value for data format is 'NCDHW'. Default: 'NCDHW'. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, D_{out},H_{out}, W_{out})`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> net = nn.BatchNorm3d(num_features=3) | |||
| >>> np.random.seed(0) | |||
| >>> input = Tensor(np.random.randint(0, 255, [16, 3, 10, 32, 32]), mindspore.float32) | |||
| >>> output = net(input) | |||
| >>> print(output.shape) | |||
| (16, 3, 10, 32, 32) | |||
| """ | |||
| def __init__(self, | |||
| num_features, | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| gamma_init='ones', | |||
| beta_init='zeros', | |||
| moving_mean_init='zeros', | |||
| moving_var_init='ones', | |||
| use_batch_statistics=None, | |||
| data_format='NCDHW'): | |||
| super(BatchNorm3d, self).__init__() | |||
| self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name) | |||
| self.bn2d = BatchNorm2d(num_features=num_features, | |||
| eps=eps, | |||
| momentum=momentum, | |||
| affine=affine, | |||
| gamma_init=gamma_init, | |||
| beta_init=beta_init, | |||
| moving_mean_init=moving_mean_init, | |||
| moving_var_init=moving_var_init, | |||
| use_batch_statistics=use_batch_statistics, | |||
| data_format="NCHW") | |||
| self.shape = P.Shape() | |||
| self.reshape = P.Reshape() | |||
| def construct(self, input_x): | |||
| x_shape = self.shape(input_x) | |||
| input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4])) | |||
| bn2d_out = self.bn2d(input_x) | |||
| bn3d_out = self.reshape(bn2d_out, x_shape) | |||
| return bn3d_out | |||
| class GlobalBatchNorm(_BatchNorm): | |||
| r""" | |||
| Global normalization layer over a N-dimension input. | |||
| @@ -6837,6 +6837,18 @@ class Conv3D(PrimitiveWithInfer): | |||
| pad = (pad,) * 6 | |||
| validator.check_equal_int(len(pad), 6, 'pad size', self.name) | |||
| self.padding = pad | |||
| validator.check_int_range(self.padding[0], 0, kernel_size[0], Rel.INC_LEFT, | |||
| 'pad_d belonging [0, kernel_size_d)', self.name) | |||
| validator.check_int_range(self.padding[1], 0, kernel_size[0], Rel.INC_LEFT, | |||
| 'pad_d belonging [0, kernel_size_d)', self.name) | |||
| validator.check_int_range(self.padding[2], 0, kernel_size[1], Rel.INC_LEFT, | |||
| 'pad_h belonging [0, kernel_size_h)', self.name) | |||
| validator.check_int_range(self.padding[3], 0, kernel_size[1], Rel.INC_LEFT, | |||
| 'pad_h belonging [0, kernel_size_h)', self.name) | |||
| validator.check_int_range(self.padding[4], 0, kernel_size[2], Rel.INC_LEFT, | |||
| 'pad_w belonging [0, kernel_size_w)', self.name) | |||
| validator.check_int_range(self.padding[5], 0, kernel_size[2], Rel.INC_LEFT, | |||
| 'pad_w belonging [0, kernel_size_w)', self.name) | |||
| self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) | |||
| self.add_prim_attr('pad_mode', self.pad_mode) | |||
| @@ -7130,6 +7142,18 @@ class Conv3DTranspose(PrimitiveWithInfer): | |||
| self.pad_list = pad | |||
| for item in self.pad_list: | |||
| validator.check_non_negative_int(item, 'pad item', self.name) | |||
| validator.check_int_range(self.pad_list[0], 0, kernel_size[0], Rel.INC_LEFT, | |||
| 'pad_d belonging [0, kernel_size_d)', self.name) | |||
| validator.check_int_range(self.pad_list[1], 0, kernel_size[0], Rel.INC_LEFT, | |||
| 'pad_d belonging [0, kernel_size_d)', self.name) | |||
| validator.check_int_range(self.pad_list[2], 0, kernel_size[1], Rel.INC_LEFT, | |||
| 'pad_h belonging [0, kernel_size_h)', self.name) | |||
| validator.check_int_range(self.pad_list[3], 0, kernel_size[1], Rel.INC_LEFT, | |||
| 'pad_h belonging [0, kernel_size_h)', self.name) | |||
| validator.check_int_range(self.pad_list[4], 0, kernel_size[2], Rel.INC_LEFT, | |||
| 'pad_w belonging [0, kernel_size_w)', self.name) | |||
| validator.check_int_range(self.pad_list[5], 0, kernel_size[2], Rel.INC_LEFT, | |||
| 'pad_w belonging [0, kernel_size_w)', self.name) | |||
| self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) | |||
| self.add_prim_attr('mode', self.mode) | |||
| self.group = validator.check_positive_int(group, 'group', self.name) | |||
| @@ -7140,18 +7164,12 @@ class Conv3DTranspose(PrimitiveWithInfer): | |||
| self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name, | |||
| allow_five=True, ret_five=True, greater_zero=False) | |||
| if self.output_padding[2] < 0 or (self.output_padding[2] >= self.dilation[2] | |||
| and self.output_padding[2] >= self.stride[2]): | |||
| raise ValueError("In op, the value of [output_padding D] should be [[0, max(stride D,dilation D))], " | |||
| "but it is {}.".format(self.output_padding[2])) | |||
| if self.output_padding[3] < 0 or (self.output_padding[3] >= self.dilation[3] | |||
| and self.output_padding[3] >= self.stride[3]): | |||
| raise ValueError("In op, the value of [output_padding H] should be [[0, max(stride H,dilation H))], " | |||
| "but it is {}.".format(self.output_padding[3])) | |||
| if self.output_padding[4] < 0 or (self.output_padding[4] >= self.dilation[4] | |||
| and self.output_padding[4] >= self.stride[4]): | |||
| raise ValueError("In op, the value of [output_padding W] should be [[0, max(stride W,dilation W))], " | |||
| "but it is {}.".format(self.output_padding[4])) | |||
| validator.check_int_range(self.output_padding[2], 0, max(self.dilation[2], self.stride[2]), Rel.INC_LEFT, | |||
| 'output_padding_d belonging [0, max(stride_d, dilation_d))', self.name) | |||
| validator.check_int_range(self.output_padding[3], 0, max(self.dilation[3], self.stride[3]), Rel.INC_LEFT, | |||
| 'output_padding_h belonging [0, max(stride_h,dilation_h))', self.name) | |||
| validator.check_int_range(self.output_padding[4], 0, max(self.dilation[4], self.stride[4]), Rel.INC_LEFT, | |||
| 'output_padding_w belonging [0, max(stride_w,dilation_w))', self.name) | |||
| self.add_prim_attr('output_padding', self.output_padding) | |||
| def __infer__(self, x, w, b=None): | |||
| @@ -28,6 +28,7 @@ from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.ops.operations import _quant_ops as Q | |||
| from mindspore.ops.operations import nn_ops as nps | |||
| from mindspore.nn.layer import normalization | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| @@ -229,6 +230,18 @@ class Moments(nn.Cell): | |||
| return mean, variance | |||
| class BatchNorm3d(nn.Cell): | |||
| """BatchNorm3d net definition""" | |||
| def __init__(self, num_features): | |||
| super(BatchNorm3d, self).__init__() | |||
| self.bn3d = normalization.BatchNorm3d(num_features=num_features) | |||
| def construct(self, input_x): | |||
| bn3d_out = self.bn3d(input_x) | |||
| return bn3d_out | |||
| class ClipByNorm(nn.Cell): | |||
| """ClipByNorm net definition""" | |||
| @@ -1240,6 +1253,10 @@ test_case_math_ops = [ | |||
| 'block': Moments(axis=(), keep_dims=False), | |||
| 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ('BatchNorm3d', { | |||
| 'block': BatchNorm3d(num_features=3), | |||
| 'desc_inputs': [Tensor(np.random.rand(3, 3, 3, 5, 4).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ('Conv3D', { | |||
| 'block': Conv3D(out_channel=32, kernel_size=(4, 3, 3), mode=1, pad_mode='valid', pad=0, | |||
| stride=1, dilation=1, group=1, data_format="NCDHW"), | |||