|
|
|
@@ -39,6 +39,7 @@ __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm' |
|
|
|
|
|
|
|
SYNC_BN_GROUP_NAME = "" |
|
|
|
|
|
|
|
|
|
|
|
class _BatchNorm(Cell): |
|
|
|
"""Batch Normalization base class.""" |
|
|
|
|
|
|
|
@@ -69,6 +70,8 @@ class _BatchNorm(Cell): |
|
|
|
if context.get_context("device_target") != "GPU" and self.format == "NHWC": |
|
|
|
raise ValueError("NHWC format only support in GPU target.") |
|
|
|
self.use_batch_statistics = use_batch_statistics |
|
|
|
if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool): |
|
|
|
raise ValueError("use_batch_statistics should be a boolean value or None.") |
|
|
|
self.num_features = num_features |
|
|
|
self.eps = eps |
|
|
|
self.moving_mean = Parameter(initializer( |
|
|
|
@@ -95,7 +98,7 @@ class _BatchNorm(Cell): |
|
|
|
if self.rank_id in self.rank_list[i]: |
|
|
|
self.is_global = True |
|
|
|
if SYNC_BN_GROUP_NAME == "": |
|
|
|
SYNC_BN_GROUP_NAME = "sync_bn_group"+ str(i) |
|
|
|
SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i) |
|
|
|
management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) |
|
|
|
# for SyncBatchNorm |
|
|
|
if self.process_groups != 0: |
|
|
|
@@ -105,7 +108,7 @@ class _BatchNorm(Cell): |
|
|
|
validator.check_isinstance("process_groups", self.process_groups, list) |
|
|
|
self._check_rank_ids(self.process_groups, self.rank_size) |
|
|
|
for i in range(len(self.process_groups)): |
|
|
|
validator.check_isinstance("process_groups[" + str(i) +"]", self.process_groups[i], list) |
|
|
|
validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list) |
|
|
|
self.group_device_num = len(self.process_groups[i]) |
|
|
|
if self.rank_id in self.process_groups[i] and self.group_device_num > 1: |
|
|
|
self.is_global = True |
|
|
|
@@ -183,11 +186,20 @@ class _BatchNorm(Cell): |
|
|
|
def construct(self, x): |
|
|
|
_shape_check_bn(self.shape(x), self.input_dims) |
|
|
|
if self.use_batch_statistics is None: |
|
|
|
flag = self.training |
|
|
|
else: |
|
|
|
flag = self.use_batch_statistics |
|
|
|
|
|
|
|
if flag: |
|
|
|
if self.training: |
|
|
|
return self.bn_train(x, |
|
|
|
self.gamma, |
|
|
|
self.beta, |
|
|
|
self.moving_mean, |
|
|
|
self.moving_variance)[0] |
|
|
|
if not self.training: |
|
|
|
return self.bn_infer(x, |
|
|
|
self.gamma, |
|
|
|
self.beta, |
|
|
|
self.moving_mean, |
|
|
|
self.moving_variance)[0] |
|
|
|
|
|
|
|
if self.use_batch_statistics is True: |
|
|
|
return self.bn_train(x, |
|
|
|
self.gamma, |
|
|
|
self.beta, |
|
|
|
@@ -368,10 +380,14 @@ class BatchNorm2d(_BatchNorm): |
|
|
|
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. |
|
|
|
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. |
|
|
|
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. |
|
|
|
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, |
|
|
|
use the mean value and variance value of specified value. If None, the training process will use the mean |
|
|
|
and variance of current batch data and track the running mean and variance, the evaluation process will use |
|
|
|
the running mean and variance. Default: None. |
|
|
|
use_batch_statistics (bool): |
|
|
|
If true, use the mean value and variance value of current batch data and track running mean |
|
|
|
and running variance.
|
|
|
If false, use the mean value and variance value of specified value, and not track statistical value. |
|
|
|
If None, use_batch_statistics is automatically assigned according to the training and eval mode.
|
|
|
During training, the batchnorm2d process will be the same as with use_batch_statistics=True.
|
|
|
Conversely, in eval, the batchnorm2d process will be the same as with use_batch_statistics=False.
|
|
|
Default: None. |
|
|
|
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. |
|
|
|
Default: 'NCHW'. |
|
|
|
|
|
|
|
@@ -530,7 +546,7 @@ class BatchNorm3d(Cell): |
|
|
|
def construct(self, input_x):
    """Apply 3D batch normalization by reusing the internal 2D batchnorm.

    The 5D input of shape (N, C, D, H, W) is reshaped to 4D
    (N, C, D*H, W), normalized by the wrapped BatchNorm2d, and the
    result is reshaped back to the original 5D layout.

    Args:
        input_x: 5D input tensor of shape (N, C, D, H, W).

    Returns:
        Tensor with the same shape as `input_x`, batch-normalized.
    """
    x_shape = F.shape(input_x)
    _check_3d_shape(x_shape)
    # Fold the depth dimension into height so the 2D batchnorm kernel applies;
    # the duplicated reshape left over from a merge artifact has been removed.
    input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
    bn2d_out = self.bn2d(input_x)
    # Restore the original (N, C, D, H, W) layout.
    bn3d_out = self.reshape(bn2d_out, x_shape)
    return bn3d_out
|
|
|
|