From 30cbda6aaa1ddb736238c5372452b2d875a51b69 Mon Sep 17 00:00:00 2001
From: xiefangqi
Date: Sun, 27 Dec 2020 16:59:54 +0800
Subject: [PATCH] Fix rank id runtime error issue

---
 mindspore/dataset/core/config.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py
index ac44eb9cd1..904ad83aae 100644
--- a/mindspore/dataset/core/config.py
+++ b/mindspore/dataset/core/config.py
@@ -41,7 +41,7 @@ def _init_device_info():
     """
     from mindspore import context
     from mindspore.parallel._auto_parallel_context import auto_parallel_context
-    from mindspore.parallel._utils import _get_global_rank, _get_device_num
+    from mindspore.parallel._utils import _get_global_rank
     if context.get_context("device_target") == "GPU":
         rank_id = _get_global_rank()
         parallel_mode = auto_parallel_context().get_parallel_mode()
@@ -53,11 +53,15 @@ def _init_device_info():
             rank_id = cuda_id
         _config.set_rank_id(rank_id)
     elif context.get_context("device_target") == "Ascend":
-        rank_id = _get_global_rank()
-        device_num = _get_device_num()
-        # Ascend only support multi-process scenario
-        if device_num > 1:
-            _config.set_rank_id(rank_id)
+        # On Ascend, rank info is best obtained from the environment variables
+        env_rank_size = os.getenv("RANK_SIZE", None)
+        env_rank_id = os.getenv("RANK_ID", None)
+        if env_rank_size and env_rank_id:
+            # Ascend only supports the multi-process scenario
+            rank_size = int(env_rank_size.strip())
+            rank_id = int(env_rank_id.strip())
+            if rank_size > 1:
+                _config.set_rank_id(rank_id)


 def set_seed(seed):
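
Note: the new Ascend branch calls os.getenv, so it relies on os already being
imported at module level in config.py. For reference, the sketch below
reproduces the env-based rank resolution the patch introduces as a standalone
function; the helper name _rank_id_from_env is hypothetical and not part of
the patch.

    import os

    def _rank_id_from_env():
        # Read the Ascend rank info from the environment, mirroring the patch.
        env_rank_size = os.getenv("RANK_SIZE", None)
        env_rank_id = os.getenv("RANK_ID", None)
        if not (env_rank_size and env_rank_id):
            return None  # env not configured: leave the rank id unset
        rank_size = int(env_rank_size.strip())
        rank_id = int(env_rank_id.strip())
        # A rank id only applies in the multi-process scenario.
        return rank_id if rank_size > 1 else None

    # Example: with RANK_SIZE=8 and RANK_ID=3 set by the launcher,
    # _rank_id_from_env() returns 3; when the variables are unset or
    # RANK_SIZE is 1, it returns None and the rank id is left untouched.

Reading the values from the environment avoids the runtime error the old code
hit when _get_global_rank/_get_device_num were queried outside a properly
initialized distributed context, which is why the patch drops those calls on
the Ascend path.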