Browse Source

fix globalbatchnorm bug

tags/v0.3.0-alpha
zhaojichen 5 years ago
parent
commit
eb46dd9198
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      mindspore/nn/layer/normalization.py

+ 1
- 1
mindspore/nn/layer/normalization.py View File

@@ -74,7 +74,7 @@ class _BatchNorm(Cell):
management.create_group('group' + str(i), self.rank_list[i])
self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
self.shape = P.Shape()
self.reduce_mean = P.ReduceMean()
self.reduce_mean = P.ReduceMean(keep_dims=True)
self.square = P.Square()
self.sqrt = P.Sqrt()
self.cast = P.Cast()


Loading…
Cancel
Save