Browse Source

fix globalbatchnorm bug

tags/v0.3.0-alpha
zhaojichen 5 years ago
parent
commit
6c9a54afa1
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      mindspore/nn/layer/normalization.py

+ 1
- 1
mindspore/nn/layer/normalization.py View File

@@ -119,7 +119,7 @@ class _BatchNorm(Cell):
def _shape_infer(self, x):
"""global batch normalization shape and axes infer"""
if len(self.shape(x)) == 4:
axes = (0,2,3)
axes = (0, 2, 3)
re_shape = (1, self.num_features, 1, 1)
else:
axes = (0,)


Loading…
Cancel
Save