| @@ -39,6 +39,7 @@ __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm' | |||||
| SYNC_BN_GROUP_NAME = "" | SYNC_BN_GROUP_NAME = "" | ||||
| class _BatchNorm(Cell): | class _BatchNorm(Cell): | ||||
| """Batch Normalization base class.""" | """Batch Normalization base class.""" | ||||
| @@ -69,6 +70,8 @@ class _BatchNorm(Cell): | |||||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | if context.get_context("device_target") != "GPU" and self.format == "NHWC": | ||||
| raise ValueError("NHWC format only support in GPU target.") | raise ValueError("NHWC format only support in GPU target.") | ||||
| self.use_batch_statistics = use_batch_statistics | self.use_batch_statistics = use_batch_statistics | ||||
| if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool): | |||||
| raise ValueError("use_batch_statistics should be a boolean value or None.") | |||||
| self.num_features = num_features | self.num_features = num_features | ||||
| self.eps = eps | self.eps = eps | ||||
| self.moving_mean = Parameter(initializer( | self.moving_mean = Parameter(initializer( | ||||
| @@ -95,7 +98,7 @@ class _BatchNorm(Cell): | |||||
| if self.rank_id in self.rank_list[i]: | if self.rank_id in self.rank_list[i]: | ||||
| self.is_global = True | self.is_global = True | ||||
| if SYNC_BN_GROUP_NAME == "": | if SYNC_BN_GROUP_NAME == "": | ||||
| SYNC_BN_GROUP_NAME = "sync_bn_group"+ str(i) | |||||
| SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i) | |||||
| management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) | management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) | ||||
| # for SyncBatchNorm | # for SyncBatchNorm | ||||
| if self.process_groups != 0: | if self.process_groups != 0: | ||||
| @@ -105,7 +108,7 @@ class _BatchNorm(Cell): | |||||
| validator.check_isinstance("process_groups", self.process_groups, list) | validator.check_isinstance("process_groups", self.process_groups, list) | ||||
| self._check_rank_ids(self.process_groups, self.rank_size) | self._check_rank_ids(self.process_groups, self.rank_size) | ||||
| for i in range(len(self.process_groups)): | for i in range(len(self.process_groups)): | ||||
| validator.check_isinstance("process_groups[" + str(i) +"]", self.process_groups[i], list) | |||||
| validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list) | |||||
| self.group_device_num = len(self.process_groups[i]) | self.group_device_num = len(self.process_groups[i]) | ||||
| if self.rank_id in self.process_groups[i] and self.group_device_num > 1: | if self.rank_id in self.process_groups[i] and self.group_device_num > 1: | ||||
| self.is_global = True | self.is_global = True | ||||
| @@ -180,11 +183,20 @@ class _BatchNorm(Cell): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| _shape_check_bn(self.shape(x), self.input_dims) | _shape_check_bn(self.shape(x), self.input_dims) | ||||
| if self.use_batch_statistics is None: | if self.use_batch_statistics is None: | ||||
| flag = self.training | |||||
| else: | |||||
| flag = self.use_batch_statistics | |||||
| if flag: | |||||
| if self.training: | |||||
| return self.bn_train(x, | |||||
| self.gamma, | |||||
| self.beta, | |||||
| self.moving_mean, | |||||
| self.moving_variance)[0] | |||||
| if not self.training: | |||||
| return self.bn_infer(x, | |||||
| self.gamma, | |||||
| self.beta, | |||||
| self.moving_mean, | |||||
| self.moving_variance)[0] | |||||
| if self.use_batch_statistics is True: | |||||
| return self.bn_train(x, | return self.bn_train(x, | ||||
| self.gamma, | self.gamma, | ||||
| self.beta, | self.beta, | ||||
| @@ -365,10 +377,14 @@ class BatchNorm2d(_BatchNorm): | |||||
| The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. | The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. | ||||
| moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. | moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. | ||||
| The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. | The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. | ||||
| use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, | |||||
| use the mean value and variance value of specified value. If None, the training process will use the mean | |||||
| and variance of current batch data and track the running mean and variance, the evaluation process will use | |||||
| the running mean and variance. Default: None. | |||||
| use_batch_statistics (bool): | |||||
| If true, use the mean value and variance value of the current batch data and track the running mean | |||||
| and running variance. | |||||
| If false, use the specified mean value and variance value, and do not track statistics. | |||||
| If None, use_batch_statistics is automatically set according to the training or evaluation mode: | |||||
| during training, batchnorm2d behaves the same as with use_batch_statistics=True; | |||||
| conversely, in evaluation, batchnorm2d behaves the same as with use_batch_statistics=False. | |||||
| Default: None. | |||||
| data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. | data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. | ||||
| Default: 'NCHW'. | Default: 'NCHW'. | ||||
| @@ -527,7 +543,7 @@ class BatchNorm3d(Cell): | |||||
| def construct(self, input_x): | def construct(self, input_x): | ||||
| x_shape = F.shape(input_x) | x_shape = F.shape(input_x) | ||||
| _check_3d_shape(x_shape) | _check_3d_shape(x_shape) | ||||
| input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4])) | |||||
| input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4])) | |||||
| bn2d_out = self.bn2d(input_x) | bn2d_out = self.bn2d(input_x) | ||||
| bn3d_out = self.reshape(bn2d_out, x_shape) | bn3d_out = self.reshape(bn2d_out, x_shape) | ||||
| return bn3d_out | return bn3d_out | ||||