diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index f9316c7c11..9d623bc6fd 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -74,7 +74,7 @@ class _BatchNorm(Cell): management.create_group('group' + str(i), self.rank_list[i]) self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1) self.shape = P.Shape() - self.reduce_mean = P.ReduceMean() + self.reduce_mean = P.ReduceMean(keep_dims=True) self.square = P.Square() self.sqrt = P.Sqrt() self.cast = P.Cast()