Browse Source

add global batch normalization

tags/v0.2.0-alpha
zhaojichen 6 years ago
parent
commit
f7872774f3
2 changed files with 2 additions and 3 deletions
  1. +1
    -1
      mindspore/nn/layer/normalization.py
  2. +1
    -2
      tests/ut/python/nn/test_batchnorm.py

+ 1
- 1
mindspore/nn/layer/normalization.py View File

@@ -79,7 +79,7 @@ class _BatchNorm(Cell):
if self.rank_id in self.rank_list[i] and self.group != 1:
self.is_global = True
management.create_group('group' + str(i), self.rank_list[i])
self.all_reduce = _GlobalBNHelper('group' + str(i))
self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
self.shape = P.Shape()
self.reduce_mean = P.ReduceMean()
self.square = P.Square()


+ 1
- 2
tests/ut/python/nn/test_batchnorm.py View File

@@ -90,5 +90,4 @@ def test_global_bn():
device_num=size, parameter_broadcast=True)
net = GlobalBNNet()
input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32))
net.set_train()
out = net(input_data)
_executor.compile(net,input_data)

Loading…
Cancel
Save