diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py
index 6c1b19a110..098489a91d 100644
--- a/mindspore/nn/layer/__init__.py
+++ b/mindspore/nn/layer/__init__.py
@@ -18,7 +18,7 @@ Layer.
 The high-level components(Cells) used to construct the neural network.
 """
 from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish
-from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm
+from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, GlobalBatchNorm
 from .container import SequentialCell, CellList
 from .conv import Conv2d, Conv2dTranspose
 from .lstm import LSTM
@@ -29,7 +29,7 @@ from .image import ImageGradients, SSIM, PSNR
 __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
            'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
-           'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm',
+           'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm',
            'SequentialCell',
            'CellList',
            'Conv2d', 'Conv2dTranspose',
            'LSTM',
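With the export wired up above, the new cell is reachable as `nn.GlobalBatchNorm`. A minimal usage sketch, not part of the patch itself: it assumes an Ascend multi-device job where communication has already been initialized with `init()`, and the shapes are illustrative.

```python
# Hypothetical smoke test for the new export; assumes a distributed Ascend job.
import numpy as np

import mindspore
from mindspore import Tensor, context, nn
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()  # group != 1 needs rank information, so communication must be set up first

# Normalize across a group of 2 devices instead of within each device alone.
global_bn = nn.GlobalBatchNorm(num_features=3, group=2)
x = Tensor(np.ones([1, 3, 224, 224]), mindspore.float32)
y = global_bn(x)  # same shape as x: (1, 3, 224, 224)
```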
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 4aafaf031e..6456a3603d 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -20,8 +20,11 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 import mindspore.context as context
-from mindspore._checkparam import check_int_positive, check_bool, check_typename
+from mindspore._checkparam import check_bool, check_typename
 from mindspore._extends import cell_attr_register
+from mindspore.communication.management import get_group_size, get_rank
+from mindspore.communication import management
+from mindspore._checkparam import check_int_positive
 from ..cell import Cell
@@ -30,6 +33,7 @@ class _BatchNorm(Cell):
     @cell_attr_register
     def __init__(self,
                  num_features,
+                 group=1,
                  eps=1e-5,
                  momentum=0.9,
                  affine=True,
@@ -56,6 +60,22 @@
             gamma_init, num_features), name="gamma", requires_grad=affine)
         self.beta = Parameter(initializer(
             beta_init, num_features), name="beta", requires_grad=affine)
+        self.group = check_int_positive(group)
+        self.is_global = False  # construct() reads this on the GE path, so it needs a default
+        if self.group != 1:
+            self.rank_id = get_rank()
+            self.rank_size = get_group_size()
+            self.device_list = [i for i in range(0, self.rank_size)]
+            self.rank_list = self.list_group(self.device_list, self.group)
+            self.rank_list_idx = len(self.rank_list)
+            for i in range(self.rank_list_idx):
+                if self.rank_id in self.rank_list[i] and self.group != 1:
+                    self.is_global = True
+                    management.create_group('group' + str(i), self.rank_list[i])
+                    self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
+        self.shape = P.Shape()
+        self.reduce_mean = P.ReduceMean()
+        self.square = P.Square()
 
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
@@ -82,22 +102,53 @@
     def _check_data_dim(self, x):
         raise NotImplementedError
 
+    def list_group(self, world_rank, group_size):
+        if group_size > get_group_size():
+            raise ValueError("group size cannot be greater than local rank size, group size is {}, "
+                             "local_rank_size is {}".format(group_size, get_group_size()))
+        if len(world_rank) % group_size != 0:
+            raise ValueError("world size {} must be divisible by group size {}.".format(len(world_rank), group_size))
+        world_rank_list = zip(*(iter(world_rank),) * group_size)
+        group_list = [list(i) for i in world_rank_list]
+        return group_list
+
     def construct(self, x):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
-                y, batch_mean, batch_var, _, _ = \
-                    self.bn_train(x,
-                                  self.gamma,
-                                  self.beta,
-                                  None,
-                                  None)
-
-                mean_sub = self.sub_mean(self.moving_mean, batch_mean)
-                temp_mean = self.mul_mean(mean_sub, self.momentum)
-                mean_sub2 = self.sub_var(self.moving_variance, batch_var)
-                temp_variance = self.mul_var(mean_sub2, self.momentum)
-                y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
-                y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                if self.is_global:
+                    x_mean = self.reduce_mean(x)
+                    x_mean_square = self.reduce_mean(self.square(x))
+                    global_batch_mean = self.all_reduce(x_mean) / self.group
+                    global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+                    global_mean = global_batch_mean
+                    global_var = global_batch_mean_square - self.square(global_batch_mean)
+                    y, batch_mean, batch_var, _, _ = \
+                        self.bn_train(x,
+                                      self.gamma,
+                                      self.beta,
+                                      None,
+                                      None)
+
+                    mean_sub = self.sub_mean(self.moving_mean, global_mean)
+                    temp_mean = self.mul_mean(mean_sub, self.momentum)
+                    mean_sub2 = self.sub_var(self.moving_variance, global_var)
+                    temp_variance = self.mul_var(mean_sub2, self.momentum)
+                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
+                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                else:
+                    y, batch_mean, batch_var, _, _ = \
+                        self.bn_train(x,
+                                      self.gamma,
+                                      self.beta,
+                                      None,
+                                      None)
+
+                    mean_sub = self.sub_mean(self.moving_mean, batch_mean)
+                    temp_mean = self.mul_mean(mean_sub, self.momentum)
+                    mean_sub2 = self.sub_var(self.moving_variance, batch_var)
+                    temp_variance = self.mul_var(mean_sub2, self.momentum)
+                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
+                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
             else:
                 y = self.bn_train(x,
                                   self.gamma,
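In the `is_global` branch above, each device contributes only two reductions over its local tensor: the mean and the mean of squares. Summing those across the group with `AllReduce` and dividing by `self.group` yields the group statistics, with the variance recovered from the identity Var[x] = E[x²] − (E[x])². A NumPy sketch of that aggregation, simulating the all-reduce with a plain sum over equally sized per-device shards (equal shard sizes are implied by the fixed group layout):

```python
# Simulates the two-AllReduce aggregation from the is_global branch with NumPy.
import numpy as np

shards = [np.random.randn(4, 3) for _ in range(2)]   # per-device local batches
group = len(shards)

mean_sum = sum(s.mean() for s in shards)             # AllReduce(SUM) of x_mean
mean_sq_sum = sum((s ** 2).mean() for s in shards)   # AllReduce(SUM) of x_mean_square

global_mean = mean_sum / group                       # global_batch_mean
global_var = mean_sq_sum / group - global_mean ** 2  # E[x^2] - (E[x])^2

full = np.concatenate(shards)                        # what one big device would see
assert np.isclose(global_mean, full.mean())
assert np.isclose(global_var, full.var())
```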
@@ -221,6 +272,55 @@ class BatchNorm2d(_BatchNorm):
         pass
 
 
+class GlobalBatchNorm(_BatchNorm):
+    r"""
+    Global Batch Normalization layer over an N-dimensional input.
+
+    Global Batch Normalization is cross-device synchronized Batch Normalization. The standard Batch
+    Normalization implementation only normalizes the data within each device, while Global Batch
+    Normalization normalizes the input within the group of devices. It has been described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
+    <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the feature using a mini-batch of data
+    and the learned parameters, as described in the following formula.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Args:
+        num_features (int): `C` from an expected input of size (N, C, H, W).
+        group (int): The number of devices in each group.
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
+        momentum (float): A floating-point hyperparameter of the momentum for the
+            running_mean and running_var computation. Default: 0.9.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch of data;
+            otherwise, use the moving mean and moving variance. Default: True.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor, the normalized, scaled and offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+
+    Examples:
+        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
+        >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
+        >>> global_bn_op(input)
+    """
+
+    def _check_data_dim(self, x):
+        if x.dim == 0:
+            pass
+
+
 class LayerNorm(Cell):
     r"""
     Applies Layer Normalization over a mini-batch of inputs.
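For reference, the rank partitioning performed by `list_group` can be reproduced in plain Python. The `zip` over a repeated shared iterator consumes `group_size` consecutive ranks per tuple, so eight devices with `group=4` split into two synchronization groups:

```python
# Standalone re-implementation of the list_group chunking, for illustration only.
def list_group(world_rank, group_size):
    if len(world_rank) % group_size != 0:
        raise ValueError("world size must be divisible by group size.")
    # A shared iterator repeated group_size times yields consecutive chunks.
    return [list(chunk) for chunk in zip(*(iter(world_rank),) * group_size)]

print(list_group(list(range(8)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```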