Browse Source

add global batch normalization

tags/v0.2.0-alpha
zhaojichen 5 years ago
parent
commit
d2b04664ca
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      mindspore/nn/layer/normalization.py

+ 2
- 2
mindspore/nn/layer/normalization.py View File

@@ -110,9 +110,9 @@ class _BatchNorm(Cell):
raise NotImplementedError raise NotImplementedError


def list_group(self, world_rank, group_size): def list_group(self, world_rank, group_size):
if group_size > get_local_rank_size():
if group_size > get_group_size():
raise ValueError("group size can not be greater than local rank size, group size is {}, " raise ValueError("group size can not be greater than local rank size, group size is {}, "
"local_rank_size is {}".format(group_size, get_local_rank_size()))
"local_rank_size is {}".format(group_size, get_group_size()))
if len(world_rank) % group_size != 0: if len(world_rank) % group_size != 0:
raise ValueError("please make your group size correct.") raise ValueError("please make your group size correct.")
world_rank_list = zip(*(iter(world_rank),) *group_size) world_rank_list = zip(*(iter(world_rank),) *group_size)


Loading…
Cancel
Save