|
|
|
@@ -73,7 +73,9 @@ GroupManager::GroupManager() { groups_.clear(); } |
|
|
|
#if !defined(NO_DLIB) || defined(ENABLE_GPU) |
|
|
|
bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name, |
|
|
|
const std::vector<uint32_t> ranks, int device_id) { |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) { |
|
|
|
// The group operation thread must be the same as the NCCL init thread on the GPU device. |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) || |
|
|
|
(MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) { |
|
|
|
return CommManager::GetInstance().CreateGroupSync(group_name, ranks); |
|
|
|
} else { |
|
|
|
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id); |
|
|
|
@@ -84,7 +86,9 @@ bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const s |
|
|
|
|
|
|
|
bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name, |
|
|
|
int device_id) { |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) { |
|
|
|
// The group operation thread must be the same as the NCCL init thread on the GPU device. |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) || |
|
|
|
(MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) { |
|
|
|
return CommManager::GetInstance().DestroyGroup(group_name); |
|
|
|
} else { |
|
|
|
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id); |
|
|
|
@@ -103,7 +107,9 @@ Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_ |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
for (auto &group : group_info) { |
|
|
|
bool ret = true; |
|
|
|
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) { |
|
|
|
// The group operation thread must be the same as the NCCL init thread on the GPU device. |
|
|
|
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) || |
|
|
|
(context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) { |
|
|
|
ret = CommManager::GetInstance().CreateGroupSync(group.first, group.second); |
|
|
|
} else { |
|
|
|
ret = executor->CreateCommGroup(group.first, group.second); |
|
|
|
|