From b9f45b49ffa074a6572bab8e026d08f35a33276f Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Wed, 17 Mar 2021 14:50:22 +0800 Subject: [PATCH] Fix profiling device id error --- .../minddata/dataset/engine/perf/profiling.cc | 12 ++++++++---- mindspore/dataset/core/config.py | 17 ++++++++++++----- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc index 2f5ee8e7eb..cea23b9857 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc @@ -19,7 +19,8 @@ #include #include #include "utils/ms_utils.h" -#include "minddata/dataset/util/path.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/perf/monitor.h" #include "minddata/dataset/engine/perf/device_queue_tracing.h" #include "minddata/dataset/engine/perf/connector_size.h" @@ -27,6 +28,7 @@ #include "minddata/dataset/engine/perf/cpu_sampling.h" #include "minddata/dataset/engine/perf/dataset_iterator_tracing.h" #include "minddata/dataset/util/log_adapter.h" +#include "minddata/dataset/util/path.h" namespace mindspore { namespace dataset { @@ -59,11 +61,13 @@ Status ProfilingManager::Initialize() { #endif dir_path_ = real_path; + std::shared_ptr cfg = GlobalContext::config_manager(); + int32_t rank_id = cfg->rank_id(); // If DEVICE_ID is not set, default value is 0 - device_id_ = common::GetEnv("DEVICE_ID"); - if (device_id_.empty()) { - device_id_ = "0"; + if (rank_id < 0) { + rank_id = 0; } + device_id_ = std::to_string(rank_id); // Register all profiling node. // device_queue node is used for graph mode diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py index 1a7eb3f56d..56a99a6881 100644 --- a/mindspore/dataset/core/config.py +++ b/mindspore/dataset/core/config.py @@ -58,14 +58,21 @@ def _init_device_info(): # Ascend is a special scenario, we'd better get rank info from env env_rank_size = os.getenv("RANK_SIZE", None) env_rank_id = os.getenv("RANK_ID", None) + rank_size = 1 + rank_id = 0 if env_rank_size and env_rank_id: - # Ascend only support multi-process scenario rank_size = int(env_rank_size.strip()) rank_id = int(env_rank_id.strip()) - if rank_size > 1: - if numa_enable: - _config.set_numa_enable(True) - _config.set_rank_id(rank_id) + if rank_size > 1: + if numa_enable: + _config.set_numa_enable(True) + _config.set_rank_id(rank_id) + else: + rank_id = 0 + env_rank_id = os.getenv("DEVICE_ID", None) + if env_rank_id: + rank_id = int(env_rank_id.strip()) + _config.set_rank_id(rank_id) def set_seed(seed):