|
|
|
@@ -18,7 +18,11 @@ |
|
|
|
#include <algorithm> |
|
|
|
#include <vector> |
|
|
|
#include <utility> |
|
|
|
#if !defined(NO_DLIB) || defined(ENABLE_GPU) |
|
|
|
#include "backend/session/executor_manager.h" |
|
|
|
#else |
|
|
|
#include "frontend/parallel/parallel_stub/executor_manager_stub.h" |
|
|
|
#endif |
|
|
|
#include "frontend/parallel/device_manager.h" |
|
|
|
#include "utils/comm_manager.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
@@ -66,6 +70,79 @@ Status Group::GetIndex(size_t *index) { |
|
|
|
|
|
|
|
// Constructor: empties groups_ so the manager begins without any recorded group.
GroupManager::GroupManager() {
  groups_.clear();
}
|
|
|
|
|
|
|
#if !defined(NO_DLIB) || defined(ENABLE_GPU) |
|
|
|
bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name, |
|
|
|
const std::vector<uint32_t> ranks, int device_id) { |
|
|
|
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id); |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
bool ret = executor->CreateCommGroup(group_name, ranks); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name, |
|
|
|
int device_id) { |
|
|
|
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id); |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
bool ret = executor->DestroyCommGroup(group_name); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
// Creates every communication group listed in `group_info` through the session executor.
// Each entry is a (group name, rank list) pair. The device target and device id are
// read from the global MsContext to select the executor.
// Returns FAILED as soon as one group cannot be created (remaining entries are not
// attempted); returns SUCCESS once all entries have been created.
Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const std::string target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  const uint32_t dev_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);

  auto comm_executor = session::ExecutorManager::Instance().GetExecutor(target, dev_id);
  MS_EXCEPTION_IF_NULL(comm_executor);

  for (const auto &entry : group_info) {
    if (comm_executor->CreateCommGroup(entry.first, entry.second)) {
      MS_LOG(INFO) << "Create group success, group name is " << entry.first << ", ranks is " << entry.second;
      continue;
    }
    // First failure aborts the whole batch.
    MS_LOG(ERROR) << "Create group failed, group name is " << entry.first << ", ranks is " << entry.second;
    return FAILED;
  }

  return SUCCESS;
}
|
|
|
#else |
|
|
|
bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name, |
|
|
|
const std::vector<uint32_t> ranks, int device_id) { |
|
|
|
MS_LOG(WARNING) << "Create group in stub"; |
|
|
|
auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id); |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
return executor->CreateCommGroup(group_name, ranks); |
|
|
|
} |
|
|
|
|
|
|
|
bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name, |
|
|
|
int device_id) { |
|
|
|
MS_LOG(WARNING) << "Destroy group in stub"; |
|
|
|
auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id); |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
return executor->DestroyCommGroup(group_name); |
|
|
|
} |
|
|
|
|
|
|
|
// Stub-build variant of CreateGroups: creates every (group name, rank list) entry of
// `group_info` through the parallel::ExecutorManager executor selected by the device
// target and device id stored in the global MsContext.
// Returns FAILED at the first entry that cannot be created; SUCCESS when all succeed.
Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const std::string target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  const uint32_t dev_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);

  auto stub_executor = parallel::ExecutorManager::Instance().GetExecutor(target, dev_id);
  MS_EXCEPTION_IF_NULL(stub_executor);

  for (const auto &entry : group_info) {
    if (stub_executor->CreateCommGroup(entry.first, entry.second)) {
      MS_LOG(INFO) << "Create group success, group name is " << entry.first << ", ranks is " << entry.second;
      continue;
    }
    // Stop at the first group that could not be created.
    MS_LOG(ERROR) << "Create group failed, group name is " << entry.first << ", ranks is " << entry.second;
    return FAILED;
  }

  return SUCCESS;
}
|
|
|
#endif |
|
|
|
Status GroupManager::CreateGroup(const std::string &group_name, const std::vector<Device> &devices, |
|
|
|
mindspore::parallel::Group *const group) { |
|
|
|
// it is simple to use size to determine whether it is a world group |
|
|
|
@@ -102,9 +179,7 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); |
|
|
|
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); |
|
|
|
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id); |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
bool ret = executor->CreateCommGroup(group_name, ranks); |
|
|
|
bool ret = CreateGroupByExecutor(device_name, group_name, ranks, device_id); |
|
|
|
if (!ret) { |
|
|
|
MS_LOG(ERROR) << "Create group failed, group name is " << group_name; |
|
|
|
return Status::FAILED; |
|
|
|
@@ -123,9 +198,7 @@ Status GroupManager::DestroyGroup(const std::string &group_name) { |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); |
|
|
|
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); |
|
|
|
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id); |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
bool ret = executor->DestroyCommGroup(group_name); |
|
|
|
bool ret = DestroyGroupByExecutor(device_name, group_name, device_id); |
|
|
|
if (!ret) { |
|
|
|
return Status::FAILED; |
|
|
|
} |
|
|
|
@@ -192,26 +265,5 @@ Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Gro |
|
|
|
|
|
|
|
// Destroys all groups via DestroyAllGroups(); its result is intentionally discarded.
void GroupManager::Clear() {
  static_cast<void>(DestroyAllGroups());
}
|
|
|
|
|
|
|
// Creates the communication groups described by `group_info`, one (name, ranks) pair
// at a time, using the session executor chosen from the MsContext device target and
// device id. Processing stops with FAILED on the first unsuccessful creation;
// SUCCESS is returned when every group was created.
Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const std::string target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  const uint32_t dev_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto comm_executor = session::ExecutorManager::Instance().GetExecutor(target, dev_id);
  MS_EXCEPTION_IF_NULL(comm_executor);

  for (auto it = group_info.begin(); it != group_info.end(); ++it) {
    const bool created = comm_executor->CreateCommGroup(it->first, it->second);
    if (!created) {
      MS_LOG(ERROR) << "Create group failed, group name is " << it->first << ", ranks is " << it->second;
      return FAILED;
    }
    MS_LOG(INFO) << "Create group success, group name is " << it->first << ", ranks is " << it->second;
  }

  return SUCCESS;
}
|
|
|
|
|
|
|
} // namespace parallel |
|
|
|
} // namespace mindspore |