|
|
|
@@ -41,7 +41,7 @@ def _init_device_info(): |
|
|
|
""" |
|
|
|
from mindspore import context |
|
|
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context |
|
|
|
from mindspore.parallel._utils import _get_global_rank, _get_device_num |
|
|
|
from mindspore.parallel._utils import _get_global_rank |
|
|
|
if context.get_context("device_target") == "GPU": |
|
|
|
rank_id = _get_global_rank() |
|
|
|
parallel_mode = auto_parallel_context().get_parallel_mode() |
|
|
|
@@ -53,11 +53,15 @@ def _init_device_info(): |
|
|
|
rank_id = cuda_id |
|
|
|
_config.set_rank_id(rank_id) |
|
|
|
elif context.get_context("device_target") == "Ascend": |
|
|
|
rank_id = _get_global_rank() |
|
|
|
device_num = _get_device_num() |
|
|
|
# Ascend only support multi-process scenario |
|
|
|
if device_num > 1: |
|
|
|
_config.set_rank_id(rank_id) |
|
|
|
# Ascend is a special scenario, we'd better get rank info from env |
|
|
|
env_rank_size = os.getenv("RANK_SIZE", None) |
|
|
|
env_rank_id = os.getenv("RANK_ID", None) |
|
|
|
if env_rank_size and env_rank_id: |
|
|
|
# Ascend only support multi-process scenario |
|
|
|
rank_size = int(env_rank_size.strip()) |
|
|
|
rank_id = int(env_rank_id.strip()) |
|
|
|
if rank_size > 1: |
|
|
|
_config.set_rank_id(rank_id) |
|
|
|
|
|
|
|
|
|
|
|
def set_seed(seed): |
|
|
|
|