From eb46dd9198b358a8fac4fbceff260f6a363f3b8a Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 28 Apr 2020 09:56:28 -0400 Subject: [PATCH] fix globalbatchnorm bug --- mindspore/nn/layer/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index f9316c7c11..9d623bc6fd 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -74,7 +74,7 @@ class _BatchNorm(Cell): management.create_group('group' + str(i), self.rank_list[i]) self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1) self.shape = P.Shape() - self.reduce_mean = P.ReduceMean() + self.reduce_mean = P.ReduceMean(keep_dims=True) self.square = P.Square() self.sqrt = P.Sqrt() self.cast = P.Cast()