diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index dddf32ec48..2b55147cf1 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype import mindspore.context as context from mindspore._checkparam import check_bool, check_typename from mindspore._extends import cell_attr_register -from mindspore.communication.management import get_local_rank_size, get_rank +from mindspore.communication.management import get_group_size, get_rank from mindspore.communication import management from mindspore._checkparam import check_int_positive from ..cell import Cell @@ -71,7 +71,7 @@ class _BatchNorm(Cell): self.group = check_int_positive(group) if self.group != 1: self.rank_id = get_rank() - self.rank_size = get_local_rank_size() + self.rank_size = get_group_size() self.device_list = [i for i in range(0, self.rank_size)] self.rank_list = self.list_group(self.device_list, self.group) self.rank_list_idx = len(self.rank_list) diff --git a/tests/ut/python/hccl_test/manage/api.py b/tests/ut/python/hccl_test/manage/api.py index 8dac167a3f..b684df5263 100644 --- a/tests/ut/python/hccl_test/manage/api.py +++ b/tests/ut/python/hccl_test/manage/api.py @@ -65,6 +65,14 @@ def get_rank_size(group=None): return int(group.split("-")[0]) raise ValueError +def get_group_size(group=None): + hccl = Hccl() + if group is None: + return hccl.rank_size + if isinstance(group, str): + return int(group.split("-")[0]) + raise ValueError + # pylint: disable=unused-argument def get_world_rank_from_group_rank(group, group_rank_id): return group_rank_id