Browse Source

fix globalbatchnorm bug

tags/v0.3.0-alpha
zhaojichen 6 years ago
parent
commit
8261cfd019
1 changed files with 0 additions and 2 deletions
  1. +0
    -2
      mindspore/nn/layer/normalization.py

+ 0
- 2
mindspore/nn/layer/normalization.py View File

@@ -128,8 +128,6 @@ class _BatchNorm(Cell):

def _global_sync(self, x, axes, re_shape):
"""calculate global batch normalization output"""
axes = (0, 2, 3)
re_shape = (1, self.num_features, 1, 1)
x_mean = self.reduce_mean(x, axes)
x_mean_square = self.reduce_mean(self.square(x), axes)
global_batch_mean = self.all_reduce(x_mean) / self.group


Loading…
Cancel
Save