From 9f950fb16c40d89988001e8a3a2beeb1a548c57c Mon Sep 17 00:00:00 2001
From: jiangzhenguang
Date: Sat, 26 Dec 2020 16:37:15 +0800
Subject: [PATCH] add batchnorm3d

---
 mindspore/nn/layer/normalization.py | 94 +++++++++++++++++++++++++++++
 tests/ut/python/ops/test_ops.py     | 17 ++++++
 2 files changed, 111 insertions(+)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index b326284128..795720d756 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -420,6 +420,100 @@ class BatchNorm2d(_BatchNorm):
     pass
 
 
+class BatchNorm3d(Cell):
+    r"""
+    Batch normalization layer over a 5D input.
+
+    Batch Normalization is widely used in convolutional networks. This layer
+    applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with
+    an additional channel dimension) to reduce internal covariate shift.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Note:
+        The implementation of BatchNorm differs between graph mode and pynative mode, therefore the mode
+        cannot be changed after the network is initialized.
+        Note that the formula for updating the running_mean and running_var is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
+
+    Args:
+        num_features (int): `C` from an expected input of size (N, C, D, H, W).
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
+        momentum (float): A floating-point hyperparameter of the momentum for the
+            running_mean and running_var computation. Default: 0.9.
+        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
+            use the specified moving mean and variance values. If None, training uses the mean and variance of
+            the current batch data and tracks the running mean and variance, while evaluation uses the running
+            mean and variance. Default: None.
+        data_format (str): The optional value for data format. Only 'NCDHW' is supported. Default: 'NCDHW'.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor, the normalized, scaled and offset tensor, of shape :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> net = nn.BatchNorm3d(num_features=3)
+        >>> np.random.seed(0)
+        >>> input = Tensor(np.random.randint(0, 255, [16, 3, 10, 32, 32]), mindspore.float32)
+        >>> output = net(input)
+        >>> print(output.shape)
+        (16, 3, 10, 32, 32)
+    """
+
+    def __init__(self,
+                 num_features,
+                 eps=1e-5,
+                 momentum=0.9,
+                 affine=True,
+                 gamma_init='ones',
+                 beta_init='zeros',
+                 moving_mean_init='zeros',
+                 moving_var_init='ones',
+                 use_batch_statistics=None,
+                 data_format='NCDHW'):
+        super(BatchNorm3d, self).__init__()
+        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
+        self.bn2d = BatchNorm2d(num_features=num_features,
+                                eps=eps,
+                                momentum=momentum,
+                                affine=affine,
+                                gamma_init=gamma_init,
+                                beta_init=beta_init,
+                                moving_mean_init=moving_mean_init,
+                                moving_var_init=moving_var_init,
+                                use_batch_statistics=use_batch_statistics,
+                                data_format="NCHW")
+        self.shape = P.Shape()
+        self.reshape = P.Reshape()
+
+    def construct(self, input_x):
+        x_shape = self.shape(input_x)
+        input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4]))
+        bn2d_out = self.bn2d(input_x)
+        bn3d_out = self.reshape(bn2d_out, x_shape)
+        return bn3d_out
+
+
 class GlobalBatchNorm(_BatchNorm):
     r"""
     Global normalization layer over a N-dimension input.
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 69a954e2a7..6c647b3b67 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -28,6 +28,7 @@ from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops.operations import _quant_ops as Q
 from mindspore.ops.operations import nn_ops as nps
+from mindspore.nn.layer import normalization
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.mindspore_test import mindspore_test
 from ....mindspore_test_framework.pipeline.forward.compile_forward \
@@ -229,6 +230,18 @@ class Moments(nn.Cell):
         return mean, variance
 
 
+class BatchNorm3d(nn.Cell):
+    """BatchNorm3d net definition"""
+
+    def __init__(self, num_features):
+        super(BatchNorm3d, self).__init__()
+        self.bn3d = normalization.BatchNorm3d(num_features=num_features)
+
+    def construct(self, input_x):
+        bn3d_out = self.bn3d(input_x)
+        return bn3d_out
+
+
 class ClipByNorm(nn.Cell):
     """ClipByNorm net definition"""
 
@@ -1240,6 +1253,10 @@ test_case_math_ops = [
         'block': Moments(axis=(), keep_dims=False),
         'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
         'skip': ['backward']}),
+    ('BatchNorm3d', {
+        'block': BatchNorm3d(num_features=3),
+        'desc_inputs': [Tensor(np.random.rand(3, 3, 3, 5, 4).astype(np.float32))],
+        'skip': ['backward']}),
     ('Conv3D', {
         'block': Conv3D(out_channel=32, kernel_size=(4, 3, 3), mode=1, pad_mode='valid',
                         pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"),
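Reviewer note: the construct method above implements 3D batch normalization by folding the depth
dimension into height, running BatchNorm2d on the resulting 4D tensor, and reshaping back. Because
the set of elements reduced per channel is identical either way, the statistics match exactly. Below
is a minimal NumPy sketch of that equivalence, plus the running-statistics update from the docstring's
Note; the helper name bn_per_channel and the toy shapes are illustrative, not part of the patch.

import numpy as np

def bn_per_channel(x, axes, eps=1e-5):
    # Illustrative helper: normalize per channel (axis 1), with
    # mean/variance computed over the given reduction axes.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

np.random.seed(0)
x = np.random.rand(2, 3, 4, 5, 6).astype(np.float32)  # (N, C, D, H, W)

# Direct 3D batch norm: per-channel statistics over N, D, H, W.
direct = bn_per_channel(x, axes=(0, 2, 3, 4))

# The patch's reshape trick: fold D into H, normalize as 4D, restore the shape.
n, c, d, h, w = x.shape
via_2d = bn_per_channel(x.reshape(n, c, d * h, w), axes=(0, 2, 3)).reshape(n, c, d, h, w)

print(np.allclose(direct, via_2d, atol=1e-5))  # True

# Running-statistics update from the docstring's Note, with the default momentum = 0.9:
# new_running = (1 - momentum) * batch_stat + momentum * old_running
momentum = 0.9
batch_mean = x.mean(axis=(0, 2, 3, 4))
old_running_mean = np.zeros_like(batch_mean)  # illustrative initial value
new_running_mean = (1 - momentum) * batch_mean + momentum * old_running_mean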