|
|
|
@@ -110,9 +110,9 @@ class _BatchNorm(Cell): |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def list_group(self, world_rank, group_size): |
|
|
|
if group_size > get_local_rank_size(): |
|
|
|
if group_size > get_group_size(): |
|
|
|
raise ValueError("group size can not be greater than local rank size, group size is {}, " |
|
|
|
"local_rank_size is {}".format(group_size, get_local_rank_size())) |
|
|
|
"local_rank_size is {}".format(group_size, get_group_size())) |
|
|
|
if len(world_rank) % group_size != 0: |
|
|
|
raise ValueError("please make your group size correct.") |
|
|
|
world_rank_list = zip(*(iter(world_rank),) *group_size) |
|
|
|
|