| @@ -116,53 +116,44 @@ class _BatchNorm(Cell): | |||||
| group_list = [list(i) for i in world_rank_list] | group_list = [list(i) for i in world_rank_list] | ||||
| return group_list | return group_list | ||||
| def _global_sync(self, x): | |||||
| """calculate global batch normalization output""" | |||||
| def _shape_infer(self, x): | |||||
| """global batch normalization shape and axes infer""" | |||||
| if len(self.shape(x)) == 4: | if len(self.shape(x)) == 4: | ||||
| axes = (0, 2, 3) | |||||
| axes = (0,2,3) | |||||
| re_shape = (1, self.num_features, 1, 1) | re_shape = (1, self.num_features, 1, 1) | ||||
| x_mean = self.reduce_mean(x, axes) | |||||
| x_mean_square = self.reduce_mean(self.square(x), axes) | |||||
| global_batch_mean = self.all_reduce(x_mean) / self.group | |||||
| global_batch_mean_square = self.all_reduce(x_mean_square) / self.group | |||||
| global_mean = global_batch_mean | |||||
| global_var = global_batch_mean_square - self.square(global_mean) | |||||
| var_sqrt = self.sqrt(global_var + self.eps) | |||||
| mean_first = (x - global_mean) / var_sqrt | |||||
| y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape) | |||||
| mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean) | |||||
| tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub))) | |||||
| mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var) | |||||
| tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2))) | |||||
| y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean)) | |||||
| y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance)) | |||||
| else: | else: | ||||
| axes = (0,) | axes = (0,) | ||||
| re_shape = (1, self.num_features) | re_shape = (1, self.num_features) | ||||
| x_mean = self.reduce_mean(x, axes) | |||||
| x_mean_square = self.reduce_mean(self.square(x), axes) | |||||
| global_batch_mean = self.all_reduce(x_mean) / self.group | |||||
| global_batch_mean_square = self.all_reduce(x_mean_square) / self.group | |||||
| global_mean = global_batch_mean | |||||
| global_var = global_batch_mean_square - self.square(global_mean) | |||||
| var_sqrt = self.sqrt(global_var + self.eps) | |||||
| mean_first = (x - global_mean) / var_sqrt | |||||
| y = mean_first * self.gamma + self.beta | |||||
| mean_sub = self.sub_mean(self.moving_mean, global_mean) | |||||
| temp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub))) | |||||
| mean_sub2 = self.sub_var(self.moving_variance, global_var) | |||||
| temp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2))) | |||||
| y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), temp_mean)) | |||||
| y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), temp_variance)) | |||||
| return axes, re_shape | |||||
| def _global_sync(self, x, axes, re_shape): | |||||
| """calculate global batch normalization output""" | |||||
| axes = (0, 2, 3) | |||||
| re_shape = (1, self.num_features, 1, 1) | |||||
| x_mean = self.reduce_mean(x, axes) | |||||
| x_mean_square = self.reduce_mean(self.square(x), axes) | |||||
| global_batch_mean = self.all_reduce(x_mean) / self.group | |||||
| global_batch_mean_square = self.all_reduce(x_mean_square) / self.group | |||||
| global_mean = global_batch_mean | |||||
| global_var = global_batch_mean_square - self.square(global_mean) | |||||
| var_sqrt = self.sqrt(global_var + self.eps) | |||||
| mean_first = (x - global_mean) / var_sqrt | |||||
| y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape) | |||||
| mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean) | |||||
| tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub))) | |||||
| mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var) | |||||
| tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2))) | |||||
| y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean)) | |||||
| y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance)) | |||||
| return y | return y | ||||
| def construct(self, x): | def construct(self, x): | ||||
| if self.training and self.use_batch_statistics: | if self.training and self.use_batch_statistics: | ||||
| if self.is_ge_backend: | if self.is_ge_backend: | ||||
| if self.is_global: | if self.is_global: | ||||
| y = self._global_sync(x) | |||||
| axes, re_shape = self._shape_infer(x) | |||||
| y = self._global_sync(x, axes, re_shape) | |||||
| else: | else: | ||||
| y, batch_mean, batch_var, _, _ = \ | y, batch_mean, batch_var, _, _ = \ | ||||
| self.bn_train(x, | self.bn_train(x, | ||||