From 3a85c18ff0554965f25173bf295f2cd075f9f228 Mon Sep 17 00:00:00 2001 From: jiangzhenguang Date: Fri, 23 Apr 2021 15:15:35 +0800 Subject: [PATCH] amend bn --- mindspore/nn/layer/normalization.py | 40 ++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 0d7dc01bfd..fd3db41db1 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -39,6 +39,7 @@ __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm' SYNC_BN_GROUP_NAME = "" + class _BatchNorm(Cell): """Batch Normalization base class.""" @@ -69,6 +70,8 @@ class _BatchNorm(Cell): if context.get_context("device_target") != "GPU" and self.format == "NHWC": raise ValueError("NHWC format only support in GPU target.") self.use_batch_statistics = use_batch_statistics + if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool): + raise ValueError("use_batch_statistics should be a boolean value or None.") self.num_features = num_features self.eps = eps self.moving_mean = Parameter(initializer( @@ -95,7 +98,7 @@ class _BatchNorm(Cell): if self.rank_id in self.rank_list[i]: self.is_global = True if SYNC_BN_GROUP_NAME == "": - SYNC_BN_GROUP_NAME = "sync_bn_group"+ str(i) + SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i) management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) # for SyncBatchNorm if self.process_groups != 0: @@ -105,7 +108,7 @@ class _BatchNorm(Cell): validator.check_isinstance("process_groups", self.process_groups, list) self._check_rank_ids(self.process_groups, self.rank_size) for i in range(len(self.process_groups)): - validator.check_isinstance("process_groups[" + str(i) +"]", self.process_groups[i], list) + validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list) self.group_device_num = len(self.process_groups[i]) if self.rank_id in self.process_groups[i] 
and self.group_device_num > 1: self.is_global = True @@ -180,11 +183,20 @@ class _BatchNorm(Cell): def construct(self, x): _shape_check_bn(self.shape(x), self.input_dims) if self.use_batch_statistics is None: - flag = self.training - else: - flag = self.use_batch_statistics - - if flag: + if self.training: + return self.bn_train(x, + self.gamma, + self.beta, + self.moving_mean, + self.moving_variance)[0] + if not self.training: + return self.bn_infer(x, + self.gamma, + self.beta, + self.moving_mean, + self.moving_variance)[0] + + if self.use_batch_statistics is True: return self.bn_train(x, self.gamma, self.beta, @@ -365,10 +377,14 @@ class BatchNorm2d(_BatchNorm): The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. - use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, - use the mean value and variance value of specified value. If None, the training process will use the mean - and variance of current batch data and track the running mean and variance, the evaluation process will use - the running mean and variance. Default: None. + use_batch_statistics (bool): + If true, use the mean value and variance value of current batch data and track the running mean + and running variance. + If false, use the mean value and variance value of the specified value, and do not track statistical values. + If None, use_batch_statistics is automatically assigned according to the training and eval mode. + During training, the batchnorm2d process will be the same as with use_batch_statistics=True. + Conversely, in eval, the batchnorm2d process will be the same as with use_batch_statistics=False. + Default: None. data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
Default: 'NCHW'. @@ -527,7 +543,7 @@ class BatchNorm3d(Cell): def construct(self, input_x): x_shape = F.shape(input_x) _check_3d_shape(x_shape) - input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4])) + input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4])) bn2d_out = self.bn2d(input_x) bn3d_out = self.reshape(bn2d_out, x_shape) return bn3d_out