diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py index ac44eb9cd1..904ad83aae 100644 --- a/mindspore/dataset/core/config.py +++ b/mindspore/dataset/core/config.py @@ -41,7 +41,7 @@ def _init_device_info(): """ from mindspore import context from mindspore.parallel._auto_parallel_context import auto_parallel_context - from mindspore.parallel._utils import _get_global_rank, _get_device_num + from mindspore.parallel._utils import _get_global_rank if context.get_context("device_target") == "GPU": rank_id = _get_global_rank() parallel_mode = auto_parallel_context().get_parallel_mode() @@ -53,11 +53,15 @@ def _init_device_info(): rank_id = cuda_id _config.set_rank_id(rank_id) elif context.get_context("device_target") == "Ascend": - rank_id = _get_global_rank() - device_num = _get_device_num() - # Ascend only support multi-process scenario - if device_num > 1: - _config.set_rank_id(rank_id) + # Ascend is a special scenario, we'd better get rank info from env + env_rank_size = os.getenv("RANK_SIZE", None) + env_rank_id = os.getenv("RANK_ID", None) + if env_rank_size and env_rank_id: + # Ascend only support multi-process scenario + rank_size = int(env_rank_size.strip()) + rank_id = int(env_rank_id.strip()) + if rank_size > 1: + _config.set_rank_id(rank_id) def set_seed(seed):