Browse Source

get rank id correctly when the HCCL env is set for single-card training

r1.4
yelihua 4 years ago
parent
commit
425b1168ad
6 changed files with 48 additions and 10 deletions
  1. +3
    -2
      mindspore/ccsrc/backend/session/ascend_session.cc
  2. +1
    -1
      mindspore/ccsrc/backend/session/gpu_session.cc
  3. +1
    -0
      mindspore/ccsrc/backend/session/session_basic.cc
  4. +6
    -5
      mindspore/ccsrc/debug/data_dump/dump_json_parser.cc
  5. +4
    -2
      mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc
  6. +33
    -0
      tests/st/dump/test_data_dump.py

+ 3
- 2
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -863,9 +863,10 @@ void AscendSession::InitRuntimeResource() {
if (!runtime_instance->Init()) {
MS_LOG(EXCEPTION) << "Kernel runtime init error.";
}
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto env_rank_id = common::GetEnv("RANK_ID");
if (!(env_table_file.empty() || env_rank_id.empty())) {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
// get actual rank id if it's distribution training case.
rank_id_ = GetRankId();
}


+ 1
- 1
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -114,12 +114,12 @@ void GPUSession::Init(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
if (collective_inited) {
rank_id_ = GetRankId();
if (collective_handle_ != nullptr) {
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
rank_id_ = GetRankId();
}
}



+ 1
- 0
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -2659,6 +2659,7 @@ uint32_t GetRankId() {
world_group = kNcclWorldGroup;
} else {
MS_LOG(ERROR) << "Invalid backend: " << backend;
return rank_id;
}
if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
MS_LOG(INFO) << "Failed to get rank id.";


+ 6
- 5
mindspore/ccsrc/debug/data_dump/dump_json_parser.cc View File

@@ -351,7 +351,7 @@ void DumpJsonParser::ParseIteration(const nlohmann::json &content) {
MS_LOG(EXCEPTION) << "iteration only supports digits, {'-', '|'}, or just \"all\" but got: " << iteration_;
}
} else if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice) {
MS_LOG(WARNING) << "Dump not enabled. ";
MS_LOG(WARNING) << "Dump is not enabled. ";
} else {
MS_LOG(EXCEPTION) << "Dump Json Parse Failed. Async or E2E should be enabled. ";
}
@@ -486,14 +486,14 @@ void DumpJsonParser::JudgeDumpEnabled() {
}

if (!async_dump_enabled_ && !e2e_dump_enabled_) {
MS_LOG(WARNING) << "Dump json parse failed. Dump not enabled";
MS_LOG(WARNING) << "Dump json parse failed. Dump is not enabled";
}
if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
auto device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
if (support_devices_.find(device_id) == support_devices_.end()) {
async_dump_enabled_ = false;
e2e_dump_enabled_ = false;
MS_LOG(WARNING) << "Dump not enabled. device_id:" << device_id << " not support";
MS_LOG(WARNING) << "Dump is not enabled. device_id:" << device_id << " not support";
}
}
JsonConfigToString();
@@ -534,9 +534,10 @@ std::string DumpJsonParser::GetOpOverflowBinPath(uint32_t graph_id) const {
bin_path.append("rank_");

uint32_t rank_id = 0;
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto env_rank_id = common::GetEnv("RANK_ID");
if (!(env_table_file.empty() || env_rank_id.empty())) {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
// get actual rank id if it's distribution training case.
if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
MS_LOG(INFO) << "Failed to get rank id.";


+ 4
- 2
mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc View File

@@ -133,9 +133,11 @@ void DataDumper::SetOpMappingInfo(NotNull<aicpu::dump::OpMappingInfo *> dump_inf
}
uint32_t graph_id = kernel_graph_->graph_id();
uint32_t rank_id = 0;
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");

auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto env_rank_id = common::GetEnv("RANK_ID");
if (!(env_table_file.empty() || env_rank_id.empty())) {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
// get actual rank id if it's distribution training case.
if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
MS_LOG(INFO) << "Failed to get rank id.";


+ 33
- 0
tests/st/dump/test_data_dump.py View File

@@ -119,6 +119,17 @@ def test_e2e_dump():
run_e2e_dump()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_e2e_dump_with_hccl_env():
    """Run the E2E dump flow on Ascend with HCCL-style env vars present.

    Simulates a single-card run where RANK_TABLE_FILE / RANK_ID are set
    (the regression this commit addresses: rank id resolution when the
    HCCL env is configured but no real distributed init happened).
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # Fake an HCCL-configured environment; the table file is deliberately
    # invalid because only the presence of the variables matters here.
    os.environ["RANK_TABLE_FILE"] = "invalid_file.json"
    os.environ["RANK_ID"] = "4"
    try:
        run_e2e_dump()
    finally:
        # Remove the fake env vars so later tests in this module do not
        # inherit a bogus distributed-training environment.
        os.environ.pop("RANK_TABLE_FILE", None)
        os.environ.pop("RANK_ID", None)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -126,6 +137,17 @@ def test_cpu_e2e_dump():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
run_e2e_dump()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_e2e_dump_with_hccl_set():
    """Run the E2E dump flow on CPU while HCCL-style env vars are set.

    CPU training must ignore RANK_TABLE_FILE / RANK_ID; this guards the
    single-card path against the HCCL env leaking into rank resolution.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # Invalid table file on purpose: only the variables' presence matters.
    os.environ["RANK_TABLE_FILE"] = "invalid_file.json"
    os.environ["RANK_ID"] = "4"
    try:
        run_e2e_dump()
    finally:
        # Restore a clean environment for subsequent tests.
        os.environ.pop("RANK_TABLE_FILE", None)
        os.environ.pop("RANK_ID", None)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@@ -133,6 +155,17 @@ def test_gpu_e2e_dump():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
run_e2e_dump()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_e2e_dump_with_hccl_set():
    """Run the E2E dump flow on GPU while HCCL-style env vars are set.

    GPU uses NCCL, not HCCL; this checks that stray RANK_TABLE_FILE /
    RANK_ID values do not break single-card dump on the GPU backend.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Invalid table file on purpose: only the variables' presence matters.
    os.environ["RANK_TABLE_FILE"] = "invalid_file.json"
    os.environ["RANK_ID"] = "4"
    try:
        run_e2e_dump()
    finally:
        # Restore a clean environment for subsequent tests.
        os.environ.pop("RANK_TABLE_FILE", None)
        os.environ.pop("RANK_ID", None)


class ReluReduceMeanDenseRelu(Cell):
def __init__(self, kernel, bias, in_channel, num_class):
super().__init__()


Loading…
Cancel
Save