@@ -28,7 +28,9 @@ CollectiveManager::CollectiveManager()
      host_ctx_(nullptr),
      device_ctx_(nullptr),
      host_comm_lib_(nullptr),
      host_comm_lib_instance_(nullptr),
      device_comm_lib_(nullptr),
      device_comm_lib_instance_(nullptr),
      global_rank_id_(0),
      local_rank_id_(0),
      global_rank_size_(0),

@@ -61,13 +63,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
    return false;
  }
  MS_EXCEPTION_IF_NULL(host_comm_lib_);
-  // Step 2: Create global communication group on host side.
-  if (!CreateHostGlobalCommGroup(global_group_name)) {
  // Steps 2, 3 and 4 are for the device communication library, so if the training job is launched only on
  // CPU, they are not necessary.
  // Step 2: Assign local rank id (device id) for this process.
  if (!AssignLocalRank(global_group_name)) {
    MS_LOG(ERROR) << "Failed to assign local rank id.";
    return false;
  }
  // Step 3: Initialize device side collective communication.
  if (!InitDeviceCommLib(backend)) {
    MS_LOG(ERROR) << "Failed to initialize device communication library.";
    return false;
  }
  // Step 4: Create global communication group.
  if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) {
    MS_LOG(ERROR) << "Failed to create global communication group " << global_group_name << ".";
    return false;
  }
-  // Step 3: Assign local rank id(device id) for this process.
  MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << ".";
  return true;
@@ -75,25 +89,81 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
                                                 const std::vector<uint32_t> &group_ranks) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_);
  MS_EXCEPTION_IF_NULL(device_comm_lib_);
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
  // Step 1: Create communication group on host side.
-  // Step 2: Generate device information of the root node.
-  // Step 3: Broadcast the device root information to all nodes.
-  // Step 4: Create communication group on device side.
  if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
    MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side.";
    return false;
  }
  // Step 2: Create communication group on device side.
  if (!device_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
    MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on device side.";
    return false;
  }
  // Step 3: Generate device information of the root node.
  CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name);
  MS_EXCEPTION_IF_NULL(group);
  size_t root_info_size = 0;
  void *root_info = group->GenerateRootInfo(&root_info_size);
  MS_EXCEPTION_IF_NULL(root_info);
  // Step 4: Broadcast the device root information to all nodes on host side.
  if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt, 0, group_name,
                                          nullptr)) {
    MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
    return false;
  }
  // Step 5: Initialize communication group on the device side.
  if (!group->Initialize(root_info)) {
    MS_LOG(ERROR) << "Initialize group on the device side failed.";
    return false;
  }
  return true;
}
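// For the NCCL backend, the five steps above form a standard rendezvous. A minimal sketch of the calls this
// boils down to (illustrative only; 'mpi_comm' and 'comm' are assumed handles, not names from this patch):
//   ncclUniqueId id;                                        // the opaque root info
//   if (group_rank == 0) ncclGetUniqueId(&id);              // step 3: only the root generates it
//   MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, mpi_comm);      // step 4: host-side broadcast
//   ncclCommInitRank(&comm, group_size, id, group_rank);    // step 5: every rank joins the communicator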
-bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) { return true; }
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  if (!host_comm_lib_instance_->DestroyCommunicationGroup(group_name)) {
    MS_LOG(ERROR) << "Failed to destroy communication group of " << group_name << " on the host side.";
    return false;
  }
  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
  if (!device_comm_lib_instance_->DestroyCommunicationGroup(group_name)) {
    MS_LOG(ERROR) << "Failed to destroy communication group of " << group_name << " on the device side.";
    return false;
  }
  return true;
}
-uint32_t CollectiveManager::GetRankId(const std::string &group_name) { return 0; }
uint32_t CollectiveManager::GetRankId(const std::string &group_name) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  return host_comm_lib_instance_->GetRankId(group_name);
}
-uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) { return 0; }
uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  return host_comm_lib_instance_->GetGroupSize(group_name);
}

bool CollectiveManager::Finalize() {
  if (finalized_) {
    return true;
  }

  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  if (!host_comm_lib_instance_->Finalize()) {
    MS_LOG(WARNING) << "Failed to finalize host communication library.";
  }

  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
  if (!device_comm_lib_instance_->Finalize()) {
    MS_LOG(WARNING) << "Failed to finalize device communication library.";
  }
  return true;
}
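// Finalization failures are downgraded to warnings on purpose: even if the host library fails to finalize,
// the device library still gets a chance to clean up.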
@@ -105,19 +175,34 @@ bool CollectiveManager::InitHostCommlib() {
    MS_LOG(ERROR) << "Failed to load communication library on the host side.";
    return false;
  }
-  return true;
-}
-bool CollectiveManager::CreateHostGlobalCommGroup(const std::string &global_group_name) {
  host_comm_lib_ = host_ctx_->collective_comm_lib();
  MS_EXCEPTION_IF_NULL(host_comm_lib_);
-  if (global_group_ranks_.empty()) {
-    MS_LOG(ERROR) << "The global group rank list is empty.";
  auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_);
  host_comm_lib_instance_ = instance_func();
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);

  // For some communication libraries, 'global_rank_id_' and 'global_rank_size_' should be set by the caller,
  // e.g., when using the MindSpore communication framework. For other communication libraries, e.g., OpenMPI,
  // the global rank id and size are generated by the library itself, and these two parameters are not used.
  MS_LOG(INFO) << "Start initializing communication library on host side...";
  if (!host_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_)) {
    MS_LOG(ERROR) << "Failed to initialize communication library on host side.";
    return false;
  }

  // Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
  global_rank_id_ = host_comm_lib_instance_->global_rank_id();
  global_rank_size_ = host_comm_lib_instance_->global_rank_size();
  for (uint32_t i = 0; i < global_rank_size_; i++) {
    global_group_ranks_.push_back(i);
  }

  MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
               << ", global rank size: " << global_rank_size_;
  return true;
}
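// Illustrative note on the two modes described above (a sketch, not part of this patch): an OpenMPI-backed
// host library can derive both values itself, roughly
//   MPI_Comm_rank(MPI_COMM_WORLD, &rank);   // replaces the 'global_rank_id_' passed in
//   MPI_Comm_size(MPI_COMM_WORLD, &size);   // replaces the 'global_rank_size_' passed in
// while the MindSpore communication framework relies on the caller having invoked set_global_rank_id and
// set_global_rank_size beforehand.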
-bool CollectiveManager::InitDeviceCommLib(const std::string &backend, uint32_t device_id) {
bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
  std::string device_name;
  if (backend == "nccl") {
    device_name = "GPU";

@@ -128,13 +213,68 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend, uint32_t d
    return false;
  }
-  device::DeviceContextKey device_key = {device_name, device_id};
  device::DeviceContextKey device_key = {device_name, local_rank_id_};
  device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
  MS_EXCEPTION_IF_NULL(device_ctx_);
  if (!device_ctx_->LoadCollectiveCommLib()) {
    MS_LOG(ERROR) << "Failed to load communication library on the device side.";
    return false;
  }
  device_comm_lib_ = device_ctx_->collective_comm_lib();
  MS_EXCEPTION_IF_NULL(device_comm_lib_);
  auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_);
  device_comm_lib_instance_ = instance_func();
  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);

  MS_LOG(INFO) << "Start initializing communication library on device side...";
  if (!device_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_)) {
    MS_LOG(ERROR) << "Failed to initialize communication library on device side.";
    return false;
  }
  MS_LOG(INFO) << "Communication library on device side is successfully initialized.";
  return true;
}
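// The device context key above is built from 'local_rank_id_', i.e. each process binds to the device whose
// index equals its local rank on the machine; this is why AssignLocalRank must run before InitDeviceCommLib.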
bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
  char host_name[MAX_HOSTNAME_LEN] = {0};
#ifndef _WIN32
  if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
    MS_LOG(ERROR) << "Failed to get host name.";
    return false;
  }
#endif
  MS_LOG(INFO) << "Host name for rank " << global_rank_id_ << " is " << host_name;

  // Generate a host name hash for every process. Host names of different physical machines should not be the
  // same so that local rank ids do not repeat.
  size_t host_hash = std::hash<std::string>()(host_name);
  const uint32_t kGlobalRankSize = global_rank_size_;
  size_t all_host_hashs[kGlobalRankSize];
  if (global_rank_id_ >= global_rank_size_) {
    MS_LOG(ERROR) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
                  << global_rank_size_;
    return false;
  }
  all_host_hashs[global_rank_id_] = host_hash;

  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  // AllGather the host name hashes across the global communication group.
  if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, sizeof(size_t), TypeId::kNumberTypeInt,
                                          global_group_name, nullptr)) {
    MS_LOG(ERROR) << "AllGather for host hashes failed.";
    return false;
  }

  // Accumulate the local rank id: count the earlier global ranks that run on the same host.
  for (uint32_t rank = 0; rank < global_rank_size_; rank++) {
    if (rank == global_rank_id_) {
      break;
    }
    if (all_host_hashs[rank] == all_host_hashs[global_rank_id_]) {
      local_rank_id_++;
    }
  }
  MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_;
  return true;
}
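// Worked example of the accumulation above: 4 processes on 2 hosts gather the hashes
// [h(A), h(A), h(B), h(B)]. Rank 1 finds one earlier rank with hash h(A), so its local rank is 1; rank 2
// finds no earlier rank with hash h(B), so its local rank is 0. The assigned local ranks are [0, 1, 0, 1].
// Note that a std::hash collision between two different host names would incorrectly merge their local rank
// spaces, although this is unlikely in practice.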
}  // namespace collective

@@ -22,6 +22,7 @@
#include <vector>
#include <atomic>
#include "utils/ms_utils.h"
#include "distributed/constants.h"
#include "runtime/hardware/device_context_manager.h"

namespace mindspore {

@@ -30,6 +31,8 @@ namespace collective {
using DeviceContext = device::DeviceContext;
using DeviceContextKey = device::DeviceContextKey;
using DeviceContextManager = device::DeviceContextManager;
using CollectiveCommunicationLib = device::CollectiveCommunicationLib;
using CommunicationGroupPtr = device::CommunicationGroupPtr;

// The collective communication API.
// MindSpore uses OpenMPI on CPU, NCCL on GPU, HCCL on Ascend, to achieve distributed training.

@@ -43,6 +46,9 @@ class CollectiveManager {
  // Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL.
  bool Initialize(const std::string &backend, const std::string &global_group_name);

  // Finalize the collective communication.
  bool Finalize();

  // Create communication group.
  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);

@@ -55,8 +61,10 @@ class CollectiveManager {
  // Get the size of the specified group.
  uint32_t GetGroupSize(const std::string &group_name);
-  // Finalize the collective communication.
-  bool Finalize();
  // In some cases the global rank id and rank size should be set by the caller, e.g., when using the
  // MindSpore communication framework they are generated by cluster::ClusterContext.
  void set_global_rank_id(uint32_t global_rank_id);
  void set_global_rank_size(uint32_t global_rank_size);
 private:
  CollectiveManager();

@@ -64,14 +72,11 @@ class CollectiveManager {
  // Initialize communication library on host side.
  bool InitHostCommlib();
-  // Create world communication group on the host side.
-  bool CreateHostGlobalCommGroup(const std::string &global_group_name);
  // Initialize communication library on device side.
-  bool InitDeviceCommLib(const std::string &backend, uint32_t device_id);
  bool InitDeviceCommLib(const std::string &backend);
-  // Create world communication group on the device side.
-  bool CreateDeviceGlobalCommGroup(const std::string &global_group_name);
  // Assign the local rank id for this process.
  bool AssignLocalRank(const std::string &global_group_name);

  std::atomic_bool inited_;
  std::atomic_bool finalized_;

@@ -81,9 +86,15 @@ class CollectiveManager {
  DeviceContext *host_ctx_;
  DeviceContext *device_ctx_;
  // The handles of the collective communication libraries dynamically loaded by 'dlopen'.
  // The host communication library refers to the communication library for CPU, e.g., OpenMPI and the MindSpore
  // communication framework.
  void *host_comm_lib_;
  CollectiveCommunicationLib *host_comm_lib_instance_;

  // The device communication library refers to the communication library for NPU or GPU, e.g., NCCL and HCCL.
  // When only the CPU backend is used, the device communication library should not be initialized.
  void *device_comm_lib_;
  CollectiveCommunicationLib *device_comm_lib_instance_;

  // The global rank id of this process. Normally it ranges from 0 to `total process number - 1`.
  uint32_t global_rank_id_;
@@ -34,6 +34,7 @@ constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
const std::set<std::string> kValidRoleName = {kEnvRoleOfServer, kEnvRoleOfWorker, kEnvRoleOfScheduler};

constexpr char kLocalHost[] = "127.0.0.1";
constexpr int MAX_HOSTNAME_LEN = 1024;
const uint16_t kDefaultSchedPort = 6667;
const uint16_t kMaxPort = 65535;
}  // namespace distributed
@@ -21,6 +21,7 @@
#include <memory>
#include <vector>
#include "utils/dlopen_macro.h"
#include "runtime/hardware/collective/collective_communication_lib.h"

namespace mindspore {
namespace device {

@@ -51,17 +52,6 @@ using CollectiveCommLibLoaderPtr = std::shared_ptr<CollectiveCommLibLoader>;
}  // namespace device
}  // namespace mindspore
-#ifndef _WIN32
-// The exported symbols of collective communication shared library is registered here.
-ORIGIN_METHOD(InitializeCollectiveLib, bool, uint32_t, uint32_t)
-ORIGIN_METHOD(FinalizeCollectiveLib, bool)
-ORIGIN_METHOD(CreateCommunicationGroup, bool, const std::string &, const std::vector<uint32_t> &)
-ORIGIN_METHOD(DestroyCommunicationGroup, bool, const std::string &)
-ORIGIN_METHOD(GetRankId, uint32_t, const std::string &)
-ORIGIN_METHOD(GetCommunicationGroupSize, uint32_t, const std::string &)
-ORIGIN_METHOD(AssignLocalRank, bool)
-ORIGIN_METHOD(global_rank_id, uint32_t)
-ORIGIN_METHOD(local_rank_id, uint32_t)
-ORIGIN_METHOD(global_rank_size, uint32_t)
-#endif
ORIGIN_METHOD(communication_lib_instance, mindspore::device::CollectiveCommunicationLib *)
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_

@@ -55,6 +55,11 @@ uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name)
  return group->group_size();
}

CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &group_name) {
  CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
  return groups_[group_name];
}

uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }

uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }
@@ -21,6 +21,7 @@
#include <memory>
#include <vector>
#include <string>
#include "ir/dtype/type_id.h"
#include "runtime/hardware/collective/communication_group.h"

namespace mindspore {

@@ -61,6 +62,21 @@ class CollectiveCommunicationLib {
  // Assign the local rank id for this process. Normally used by collective communication library on the host side.
  virtual bool AssignLocalRank() { return true; }

  // Return communication group pointer.
  virtual CommunicationGroupPtr GetGroup(const std::string &group_name);

  // Primitive of AllGather operation.
  virtual bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                         const std::string &group_name, void *stream) {
    return true;
  }

  // Primitive of Broadcast operation.
  virtual bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                         uint32_t root_rank, const std::string &group_name, void *stream) {
    return true;
  }
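  // Note: the two bodies above are deliberate no-op defaults that report success without moving any data;
  // backend libraries (e.g., the MPI and NCCL wrappers in this patch) are expected to override them.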
  // Returns global rank id of this process.
  uint32_t global_rank_id() const;

@@ -44,10 +44,9 @@ class CommunicationGroup {
  // Finalize the communication group. For example, destroy the group, etc.
  virtual bool Finalize() = 0;
-  // Return the root rank's information. Only root rank of one group could call this method.Normally this is used for
-  // collective libraries on the device side. For NCCL group, it returns 'ncclUniqueId'. For HCCL group, it returns
-  // 'HcclRootInfo'.
-  virtual void *GenerateRootInfo() { return nullptr; }
  // Return the root rank's information and its size. Normally this is used for collective libraries on the device side.
  // For NCCL group, it returns a pointer to 'ncclUniqueId'. For HCCL group, it returns a pointer to 'HcclRootInfo'.
  virtual void *GenerateRootInfo(size_t *root_info_size) { return nullptr; }

  // Get group or global rank for the given rank.
  uint32_t GetGroupRank(uint32_t global_rank);
@@ -55,11 +55,12 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam
  return true;
}
}  // namespace cpu
}  // namespace device
}  // namespace mindspore

// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;

CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); }

bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }
@@ -80,8 +81,24 @@ uint32_t GetCommunicationGroupSize(const std::string &group_name) {
bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
  return MPICollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
               const std::string &group_name, void *stream) {
  return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream);
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
               uint32_t root_rank, const std::string &group_name, void *stream) {
  return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
                                                       group_name, stream);
}

uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); }

uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }

uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); }
}  // namespace device
}  // namespace mindspore
@@ -38,6 +38,16 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
  // Override creating method. Reuse destroying method in base class CollectiveCommunicationLib.
  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

  bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                 const std::string &group_name, void *stream) override {
    return true;
  }

  bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
                 const std::string &group_name, void *stream) override {
    return true;
  }
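  // The two overrides above are placeholders in this patch. A minimal sketch of what an MPI-backed AllGather
  // could look like ('LookupComm' is a hypothetical helper mapping a group name to its MPI_Comm; the callers
  // in this patch pass byte counts, hence MPI_CHAR):
  //   bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId, const std::string &group,
  //                  void *) override {
  //     MPI_Comm comm = LookupComm(group);
  //     return MPI_Allgather(send_buff, static_cast<int>(send_count), MPI_CHAR, recv_buff,
  //                          static_cast<int>(send_count), MPI_CHAR, comm) == MPI_SUCCESS;
  //   }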
 private:
  MPICollectiveCommLib() = default;
  ~MPICollectiveCommLib() override = default;

@@ -45,12 +55,11 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
  MPI_Group world_group_;
};
}  // namespace cpu
}  // namespace device
}  // namespace mindspore

#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
                                                           uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();

@@ -60,7 +69,15 @@ extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
                                             mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
                                             mindspore::TypeId data_type, uint32_t root_rank,
                                             const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size();
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_
@@ -41,11 +41,11 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_
  return true;
}
}  // namespace gpu
}  // namespace device
}  // namespace mindspore

// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;

CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
  return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}

@@ -70,4 +70,22 @@ uint32_t GetCommunicationGroupSize(const std::string &group_name) {
bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
  return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
               const std::string &group_name, void *stream) {
  return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name,
                                                          stream);
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
               uint32_t root_rank, const std::string &group_name, void *stream) {
  return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
                                                          group_name, stream);
}

uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }
}  // namespace device
}  // namespace mindspore

@@ -39,17 +39,26 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

  bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                 const std::string &group_name, void *stream) override {
    return true;
  }

  bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
                 const std::string &group_name, void *stream) override {
    return true;
  }
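  // As on the MPI side, these overrides are placeholders. An NCCL-backed version would map almost directly
  // onto the NCCL API (a sketch, assuming the group's 'ncclComm_t' can be looked up and the same byte-count
  // convention as the callers in this patch):
  //   ncclAllGather(send_buff, recv_buff, send_count, ncclChar, comm, static_cast<cudaStream_t>(stream));
  //   ncclBroadcast(send_buff, recv_buff, send_count, ncclChar, static_cast<int>(root_rank), comm,
  //                 static_cast<cudaStream_t>(stream));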
 private:
  NvidiaCollectiveCommLib() = default;
  ~NvidiaCollectiveCommLib() override = default;
};
}  // namespace gpu
}  // namespace device
}  // namespace mindspore

#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
                                                            uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();

@@ -59,5 +68,13 @@ extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
                                              mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
                                              mindspore::TypeId data_type, uint32_t root_rank,
                                              const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_
@@ -21,7 +21,7 @@ namespace device {
namespace gpu {
NvidiaCommunicationGroup::NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
                                                   uint32_t global_rank)
-    : CommunicationGroup(name, group_ranks, global_rank) {}
    : CommunicationGroup(name, group_ranks, global_rank), unique_id_({}), comm_(nullptr) {}

bool NvidiaCommunicationGroup::Initialize(void *root_info) {
  if (initialized_) {

@@ -50,8 +50,12 @@ bool NvidiaCommunicationGroup::Finalize() {
  return true;
}
-void *NvidiaCommunicationGroup::GenerateRootInfo() {
-  CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
void *NvidiaCommunicationGroup::GenerateRootInfo(size_t *root_info_size) {
  *root_info_size = sizeof(unique_id_);
  uint32_t group_rank = GetGroupRank(global_rank_);
  if (group_rank == 0) {
    CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
  }
  return &unique_id_;
}
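// Only the root (group rank 0) asks NCCL for a fresh unique id; every other rank returns its zero-initialized
// 'unique_id_' buffer, which CollectiveManager::CreateCommunicationGroup then overwrites via the host-side
// broadcast before calling Initialize.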
}  // namespace gpu

@@ -37,7 +37,7 @@ class NvidiaCommunicationGroup : public CommunicationGroup {
  bool Initialize(void *root_info) override;
  bool Finalize() override;
-  void *GenerateRootInfo() override;
  void *GenerateRootInfo(size_t *root_info_size) override;

 private:
  // The NCCL unique id for this group. Used to initialize this group's communicator.