Browse Source

Fix profiling device id error

tags/v1.2.0-rc1
xiefangqi 4 years ago
parent
commit
b9f45b49ff
2 changed files with 20 additions and 9 deletions
  1. +8
    -4
      mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc
  2. +12
    -5
      mindspore/dataset/core/config.py

+ 8
- 4
mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc View File

@@ -19,7 +19,8 @@
#include <cstdlib>
#include <fstream>
#include "utils/ms_utils.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/perf/monitor.h"
#include "minddata/dataset/engine/perf/device_queue_tracing.h"
#include "minddata/dataset/engine/perf/connector_size.h"
@@ -27,6 +28,7 @@
#include "minddata/dataset/engine/perf/cpu_sampling.h"
#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
#include "minddata/dataset/util/log_adapter.h"
#include "minddata/dataset/util/path.h"

namespace mindspore {
namespace dataset {
@@ -59,11 +61,13 @@ Status ProfilingManager::Initialize() {
#endif
dir_path_ = real_path;

std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
int32_t rank_id = cfg->rank_id();
// If DEVICE_ID is not set, default value is 0
device_id_ = common::GetEnv("DEVICE_ID");
if (device_id_.empty()) {
device_id_ = "0";
if (rank_id < 0) {
rank_id = 0;
}
device_id_ = std::to_string(rank_id);

// Register all profiling node.
// device_queue node is used for graph mode


+ 12
- 5
mindspore/dataset/core/config.py View File

@@ -58,14 +58,21 @@ def _init_device_info():
# Ascend is a special scenario, we'd better get rank info from env
env_rank_size = os.getenv("RANK_SIZE", None)
env_rank_id = os.getenv("RANK_ID", None)
rank_size = 1
rank_id = 0
if env_rank_size and env_rank_id:
# Ascend only support multi-process scenario
rank_size = int(env_rank_size.strip())
rank_id = int(env_rank_id.strip())
if rank_size > 1:
if numa_enable:
_config.set_numa_enable(True)
_config.set_rank_id(rank_id)
if rank_size > 1:
if numa_enable:
_config.set_numa_enable(True)
_config.set_rank_id(rank_id)
else:
rank_id = 0
env_rank_id = os.getenv("DEVICE_ID", None)
if env_rank_id:
rank_id = int(env_rank_id.strip())
_config.set_rank_id(rank_id)


def set_seed(seed):


Loading…
Cancel
Save