Browse Source

!20605 disable mindRT in control flow

Merge pull request !20605 from limingqi107/bug_fix3
tags/v1.4.0
i-robot Gitee 4 years ago
parent
commit
024856a006
4 changed files with 47 additions and 4 deletions
  1. +8
    -0
      mindspore/ccsrc/backend/session/gpu_session.cc
  2. +1
    -1
      mindspore/ccsrc/debug/debugger/debugger.cc
  3. +9
    -3
      mindspore/ccsrc/frontend/parallel/group_manager.cc
  4. +29
    -0
      mindspore/ccsrc/pipeline/jit/action.cc

+ 8
- 0
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -94,6 +94,7 @@ namespace gpu {
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using CollectiveInitializer = device::gpu::CollectiveInitializer;
using GetLocalRankId = device::gpu::GetLocalRankId;
using InitNCCLComm = device::gpu::InitNCCLComm;

void GPUSession::Init(uint32_t device_id) {
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
@@ -113,7 +114,14 @@ void GPUSession::Init(uint32_t device_id) {
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
if (collective_inited) {
rank_id_ = GetRankId();
if (collective_handle_ != nullptr) {
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
}
}

auto &json_parser = DumpJsonParser::GetInstance();
// Dump json config file if dump is enabled
json_parser.CopyJsonToDir(rank_id_);


+ 1
- 1
mindspore/ccsrc/debug/debugger/debugger.cc View File

@@ -328,7 +328,7 @@ void Debugger::PreExecute(const KernelGraphPtr &graph_ptr) {
graph_ptr_ = nullptr;
CheckGraphPtr(graph_ptr);
}
} else if (graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
} else if (debugger_enabled_ && graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
// Multiple graph, and not the initial step,
// stop only when receive the first sub run graph for each step
// if we have stopped for the last kernel before, no need to stop again


+ 9
- 3
mindspore/ccsrc/frontend/parallel/group_manager.cc View File

@@ -73,7 +73,9 @@ GroupManager::GroupManager() { groups_.clear(); }
#if !defined(NO_DLIB) || defined(ENABLE_GPU)
bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
const std::vector<uint32_t> ranks, int device_id) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
// The group operation thread must be same with nccl init thread in the GPU device.
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
(MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
return CommManager::GetInstance().CreateGroupSync(group_name, ranks);
} else {
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
@@ -84,7 +86,9 @@ bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const s

bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
int device_id) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
// The group operation thread must be same with nccl init thread in the GPU device.
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
(MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
return CommManager::GetInstance().DestroyGroup(group_name);
} else {
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
@@ -103,7 +107,9 @@ Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_
MS_EXCEPTION_IF_NULL(executor);
for (auto &group : group_info) {
bool ret = true;
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
// The group operation thread must be same with nccl init thread in the GPU device.
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
(context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
ret = CommManager::GetInstance().CreateGroupSync(group.first, group.second);
} else {
ret = executor->CreateCommGroup(group.first, group.second);


+ 29
- 0
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -55,6 +55,33 @@
namespace mindspore {
namespace pipeline {
namespace {
// Disable mindRT in the control flow scenario.
//
// When MS_CTX_ENABLE_MINDRT is currently set and the manager of the resource's func graph holds a
// number of func graphs other than one, this clears MS_CTX_ENABLE_MINDRT and replaces the backend
// stored under kBackend in the resource results with a freshly created one.
// Nothing is changed when mindRT is already disabled, when the func graph has no manager, or when
// the manager holds exactly one func graph.
//
// Raises (via MS_EXCEPTION_IF_NULL) if `res`, the global MsContext, the func graph of `res`, or the
// newly created backend is null.
void ResetMindRTEnable(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    return;
  }

  auto func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  // A func graph without a manager cannot be inspected for sub graphs; leave the setting untouched.
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    return;
  }
  // Exactly one managed func graph is treated as the single-graph case, which keeps mindRT enabled.
  if (manager->func_graphs().size() == 1) {
    return;
  }

  MS_LOG(INFO) << "Disable mindRT in the multi graphs scenario.";
  // Clear the flag before creating the backend.
  // NOTE(review): presumably CreateBackend() consults the context to pick the backend type - confirm.
  context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
  // Update the backend.
  auto new_backend = compile::CreateBackend();
  MS_EXCEPTION_IF_NULL(new_backend);
  new_backend->SetDebugger();
  res->results()[kBackend] = new_backend;
}

void TaskEmitActionForMindRT(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
// Get the mindRT backend.
@@ -544,6 +571,8 @@ bool TaskEmitAction(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "TaskEmit args error";
}
// Disable mindRT in the control flow scenario.
ResetMindRTEnable(res);
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();


Loading…
Cancel
Save