@@ -39,6 +39,7 @@ __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm'
 SYNC_BN_GROUP_NAME = ""

 class _BatchNorm(Cell):
     """Batch Normalization base class."""
@@ -69,6 +70,8 @@ class _BatchNorm(Cell):
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.use_batch_statistics = use_batch_statistics
+        if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
+            raise ValueError("use_batch_statistics should be a boolean value or None.")
         self.num_features = num_features
         self.eps = eps
         self.moving_mean = Parameter(initializer(
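
The added check closes a gap where any truthy non-bool (for example 1 or the string "true") was silently accepted as use_batch_statistics. A minimal standalone sketch of the new validation, using a hypothetical helper name rather than the full Cell:

    # Sketch of the check added above; None defers the choice to train/eval mode.
    def validate_use_batch_statistics(value):
        if value is not None and not isinstance(value, bool):
            raise ValueError("use_batch_statistics should be a boolean value or None.")
        return value

    validate_use_batch_statistics(None)   # accepted
    validate_use_batch_statistics(False)  # accepted
    validate_use_batch_statistics(1)      # raises ValueError despite being truthy
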
@@ -95,7 +98,7 @@ class _BatchNorm(Cell):
                 if self.rank_id in self.rank_list[i]:
                     self.is_global = True
                     if SYNC_BN_GROUP_NAME == "":
-                        SYNC_BN_GROUP_NAME = "sync_bn_group"+ str(i)
+                        SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
                         management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
         # for SyncBatchNorm
         if self.process_groups != 0:
@@ -105,7 +108,7 @@ class _BatchNorm(Cell):
            validator.check_isinstance("process_groups", self.process_groups, list)
            self._check_rank_ids(self.process_groups, self.rank_size)
            for i in range(len(self.process_groups)):
-                validator.check_isinstance("process_groups[" + str(i) +"]", self.process_groups[i], list)
+                validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list)
                self.group_device_num = len(self.process_groups[i])
                if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
                    self.is_global = True
@@ -180,11 +183,20 @@ class _BatchNorm(Cell):
     def construct(self, x):
         _shape_check_bn(self.shape(x), self.input_dims)
         if self.use_batch_statistics is None:
-            flag = self.training
-        else:
-            flag = self.use_batch_statistics
-
-        if flag:
-            return self.bn_train(x,
-                                 self.gamma,
-                                 self.beta,
+            if self.training:
+                return self.bn_train(x,
+                                     self.gamma,
+                                     self.beta,
+                                     self.moving_mean,
+                                     self.moving_variance)[0]
+            if not self.training:
+                return self.bn_infer(x,
+                                     self.gamma,
+                                     self.beta,
+                                     self.moving_mean,
+                                     self.moving_variance)[0]
+
+        if self.use_batch_statistics is True:
+            return self.bn_train(x,
+                                 self.gamma,
+                                 self.beta,
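
The refactor replaces the intermediate flag with explicit branches, which also makes the precedence easy to state: an explicit use_batch_statistics overrides the cell's mode, and None defers to it. A plain-Python sketch of the resulting dispatch (the returned strings stand in for the fused bn_train/bn_infer primitives):

    def select_bn_branch(use_batch_statistics, training):
        # Mirrors the control flow of the rewritten construct above.
        if use_batch_statistics is None:
            return "bn_train" if training else "bn_infer"
        if use_batch_statistics is True:
            return "bn_train"
        return "bn_infer"

    assert select_bn_branch(None, training=True) == "bn_train"
    assert select_bn_branch(None, training=False) == "bn_infer"
    assert select_bn_branch(False, training=True) == "bn_infer"  # explicit value wins
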
@@ -365,10 +377,14 @@ class BatchNorm2d(_BatchNorm):
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
-            use the mean value and variance value of specified value. If None, the training process will use the mean
-            and variance of current batch data and track the running mean and variance, the evaluation process will use
-            the running mean and variance. Default: None.
+        use_batch_statistics (bool):
+            If true, use the mean value and variance value of current batch data and track the running mean
+            and running variance.
+            If false, use the specified mean value and variance value, and do not track statistical values.
+            If None, use_batch_statistics is assigned automatically according to the training and eval mode:
+            during training, the process is the same as with use_batch_statistics=True;
+            conversely, in eval, the process is the same as with use_batch_statistics=False.
+            Default: None.
         data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
             Default: 'NCHW'.
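
Under this default, switching the cell between train and eval via set_train is all that is needed; a usage sketch (standard mindspore.nn API, shapes picked arbitrarily):

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor

    bn = nn.BatchNorm2d(num_features=3)  # use_batch_statistics=None by default
    x = Tensor(np.random.randn(2, 3, 4, 4), mindspore.float32)

    bn.set_train(True)    # training: batch statistics used, moving stats updated
    y_train = bn(x)
    bn.set_train(False)   # eval: the tracked moving mean/variance are used
    y_eval = bn(x)
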
@@ -527,7 +543,7 @@ class BatchNorm3d(Cell):
     def construct(self, input_x):
         x_shape = F.shape(input_x)
         _check_3d_shape(x_shape)
-        input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4]))
+        input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
         bn2d_out = self.bn2d(input_x)
         bn3d_out = self.reshape(bn2d_out, x_shape)
         return bn3d_out
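
BatchNorm3d reuses the 2D kernel by folding depth into height: batch norm statistics are computed per channel over all remaining axes, so reshaping (N, C, D, H, W) to (N, C, D*H, W) leaves them unchanged. A numpy sketch of the shape bookkeeping:

    import numpy as np

    x = np.random.randn(2, 3, 4, 5, 6)   # (N, C, D, H, W)
    n, c, d, h, w = x.shape
    folded = x.reshape(n, c, d * h, w)   # 4-D layout, same elements per channel

    # Per-channel mean (and likewise variance) is identical before and after folding.
    assert np.allclose(x.mean(axis=(0, 2, 3, 4)), folded.mean(axis=(0, 2, 3)))
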