| @@ -116,53 +116,44 @@ class _BatchNorm(Cell): | |||
| group_list = [list(i) for i in world_rank_list] | |||
| return group_list | |||
def _shape_infer(self, x):
    """Infer the reduction axes and broadcast shape for global batch normalization.

    Args:
        x: Input tensor. Only its rank, as reported by ``self.shape(x)``, is used.

    Returns:
        tuple: ``(axes, re_shape)``. For rank-4 input ``axes`` is ``(0, 2, 3)``
        and ``re_shape`` is ``(1, self.num_features, 1, 1)``; for any other rank
        they are ``(0,)`` and ``(1, self.num_features)``.
    """
    if len(self.shape(x)) == 4:
        # NOTE(review): assumes the feature axis is axis 1 (e.g. NCHW) -- confirm.
        axes = (0, 2, 3)
        re_shape = (1, self.num_features, 1, 1)
    else:
        axes = (0,)
        re_shape = (1, self.num_features)
    return axes, re_shape
def _global_sync(self, x, axes, re_shape):
    """Calculate the global batch normalization output.

    Args:
        x: Input tensor to normalize.
        axes: Axes handed to ``self.reduce_mean`` when computing the local
            first and second moments of ``x``.
        re_shape: Shape that gamma, beta and the moving statistics are reshaped
            to before being combined with the reduced statistics.

    Returns:
        The normalized tensor scaled by gamma and shifted by beta, chained
        through ``F.depend`` with the two moving-statistics updates.
    """
    # Honor the caller-supplied ``axes`` / ``re_shape``; do not overwrite them
    # with rank-4 constants, otherwise non-4D inputs are reduced incorrectly.
    x_mean = self.reduce_mean(x, axes)
    x_mean_square = self.reduce_mean(self.square(x), axes)

    # NOTE(review): presumably all_reduce sums over the group and self.group is
    # the group size, so these are cross-device averages -- confirm.
    global_mean = self.all_reduce(x_mean) / self.group
    global_mean_square = self.all_reduce(x_mean_square) / self.group

    # Var[x] = E[x^2] - E[x]^2
    global_var = global_mean_square - self.square(global_mean)
    var_sqrt = self.sqrt(global_var + self.eps)
    normalized = (x - global_mean) / var_sqrt
    y = normalized * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)

    # Moving-statistics deltas: (moving - batch) * momentum.
    mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
    tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
    # The variance delta must start from moving_variance (not moving_mean),
    # since it is subtracted from moving_variance below.
    var_sub = self.sub_var(self.reshape(self.moving_variance, re_shape), global_var)
    tmp_variance = self.mul_var(var_sub, self.cast(self.momentum, self.dtype(var_sub)))

    # NOTE(review): the assign_sub targets are reshape outputs rather than the
    # parameters themselves -- verify the updates reach moving_mean/moving_variance.
    y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
    y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
    return y
| def construct(self, x): | |||
| if self.training and self.use_batch_statistics: | |||
| if self.is_ge_backend: | |||
| if self.is_global: | |||
| y = self._global_sync(x) | |||
| axes, re_shape = self._shape_infer(x) | |||
| y = self._global_sync(x, axes, re_shape) | |||
| else: | |||
| y, batch_mean, batch_var, _, _ = \ | |||
| self.bn_train(x, | |||