diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 66f17e3f38..dd4ac67273 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -116,53 +116,48 @@ class _BatchNorm(Cell):
         group_list = [list(i) for i in world_rank_list]
         return group_list
 
-    def _global_sync(self, x):
-        """calculate global batch normalization output"""
+    def _shape_infer(self, x):
+        """Return the reduction axes and parameter reshape for a 2D or 4D input."""
         if len(self.shape(x)) == 4:
             axes = (0, 2, 3)
             re_shape = (1, self.num_features, 1, 1)
-            x_mean = self.reduce_mean(x, axes)
-            x_mean_square = self.reduce_mean(self.square(x), axes)
-            global_batch_mean = self.all_reduce(x_mean) / self.group
-            global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
-            global_mean = global_batch_mean
-            global_var = global_batch_mean_square - self.square(global_mean)
-            var_sqrt = self.sqrt(global_var + self.eps)
-            mean_first = (x - global_mean) / var_sqrt
-            y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
-
-            mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
-            tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
-            mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var)
-            tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
-            y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
-            y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
         else:
             axes = (0,)
             re_shape = (1, self.num_features)
-            x_mean = self.reduce_mean(x, axes)
-            x_mean_square = self.reduce_mean(self.square(x), axes)
-            global_batch_mean = self.all_reduce(x_mean) / self.group
-            global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
-            global_mean = global_batch_mean
-            global_var = global_batch_mean_square - self.square(global_mean)
-            var_sqrt = self.sqrt(global_var + self.eps)
-            mean_first = (x - global_mean) / var_sqrt
-            y = mean_first * self.gamma + self.beta
-
-            mean_sub = self.sub_mean(self.moving_mean, global_mean)
-            temp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
-            mean_sub2 = self.sub_var(self.moving_variance, global_var)
-            temp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
-            y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), temp_mean))
-            y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), temp_variance))
+        return axes, re_shape
+
+    def _global_sync(self, x, axes, re_shape):
+        """Calculate the global batch normalization output.
+
+        axes and re_shape come from _shape_infer, so both the 2D and the 4D
+        input case are handled by this single body; moving_mean and
+        moving_variance receive a momentum update as a side effect.
+        """
+        x_mean = self.reduce_mean(x, axes)
+        x_mean_square = self.reduce_mean(self.square(x), axes)
+        global_batch_mean = self.all_reduce(x_mean) / self.group
+        global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+        global_mean = global_batch_mean
+        global_var = global_batch_mean_square - self.square(global_mean)
+        var_sqrt = self.sqrt(global_var + self.eps)
+        mean_first = (x - global_mean) / var_sqrt
+        y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
+
+        # Momentum update of the running statistics: stat -= (stat - global_stat) * momentum.
+        mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
+        tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
+        # The variance delta must come from moving_variance (not moving_mean).
+        mean_sub2 = self.sub_var(self.reshape(self.moving_variance, re_shape), global_var)
+        tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
+        y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
+        y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
         return y
 
     def construct(self, x):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
                 if self.is_global:
-                    y = self._global_sync(x)
+                    axes, re_shape = self._shape_infer(x)
+                    y = self._global_sync(x, axes, re_shape)
                 else:
                     y, batch_mean, batch_var, _, _ = \
                         self.bn_train(x,