|
|
|
@@ -128,8 +128,6 @@ class _BatchNorm(Cell): |
|
|
|
|
|
|
|
def _global_sync(self, x, axes, re_shape): |
|
|
|
"""calculate global batch normalization output""" |
|
|
|
axes = (0, 2, 3) |
|
|
|
re_shape = (1, self.num_features, 1, 1) |
|
|
|
x_mean = self.reduce_mean(x, axes) |
|
|
|
x_mean_square = self.reduce_mean(self.square(x), axes) |
|
|
|
global_batch_mean = self.all_reduce(x_mean) / self.group |
|
|
|
|