@@ -85,7 +85,7 @@ bool ClusterContext::Finalize() {
  return true;
}

std::string ClusterContext::node_role() const { return node_role_; }

const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }

void ClusterContext::InitClusterConfig() {
  InitNodeRole();
@@ -49,7 +49,8 @@ class ClusterContext {
  // Finalize the cluster and process exits.
  bool Finalize();

  std::string node_role() const;

  // Return node object of this process.
  const std::shared_ptr<ps::core::Node> &node() const;

 private:
  ClusterContext();
@@ -32,8 +32,6 @@ std::shared_ptr<ClusterContext> ClusterContext::instance() {
bool ClusterContext::Initialize() const { return true; }

bool ClusterContext::Finalize() const { return true; }

std::string ClusterContext::node_role() const { return ""; }
}  // namespace cluster
}  // namespace distributed
}  // namespace mindspore
@@ -39,7 +39,6 @@ class ClusterContext {
  bool Initialize() const;
  bool Finalize() const;
  std::string node_role() const;

 private:
  ClusterContext() = default;
@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include <memory>
#include "utils/ms_context.h"

namespace mindspore {
namespace distributed {
@@ -27,9 +28,7 @@ CollectiveManager::CollectiveManager()
      finalized_(true),
      host_ctx_(nullptr),
      device_ctx_(nullptr),
      host_comm_lib_(nullptr),
      host_comm_lib_instance_(nullptr),
      device_comm_lib_(nullptr),
      device_comm_lib_instance_(nullptr),
      global_rank_id_(0),
      local_rank_id_(0),
@@ -51,11 +50,13 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
  return instance;
}

bool CollectiveManager::Initialize(const std::string &backend, const std::string &global_group_name) {
bool CollectiveManager::Initialize() {
  if (inited_) {
    return true;
  }

  MS_LOG(INFO) << "Start initializing collective communication for backend: " << backend << "...";
  device_type_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  MS_LOG(INFO) << "Start initializing collective communication for backend: " << device_type_ << "...";

  // Step 1: Initialize host side collective communication.
  if (!InitHostCommlib()) {
@@ -66,24 +67,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
  // Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will not
  // be necessary.
  // Step 2: Assign local rank id (device id) for this process.
  if (!AssignLocalRank(global_group_name)) {
  if (!AssignLocalRank()) {
    MS_LOG(ERROR) << "Failed to assign local rank id.";
    return false;
  }

  // Step 3: Initialize device side collective communication.
  if (!InitDeviceCommLib(backend)) {
  if (!InitDeviceCommLib()) {
    MS_LOG(ERROR) << "Failed to initialize device communication library.";
    return false;
  }

  // Step 4: Create global communication group.
  if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) {
  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
  if (!CreateCommunicationGroup(device_comm_lib_instance_->global_group_name(), global_group_ranks_)) {
| MS_LOG(ERROR) << "Failed to initialize host communication library."; | |||
    return false;
  }

  MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << ".";
  MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_;
  return true;
}
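Net effect at call sites: the backend string and global group name arguments disappear, since the backend is now read from MsContext and the world-group name from the device communication library. A minimal before/after sketch with a hypothetical caller that previously passed NCCL's defaults:

  // Before: backend and world-group name were supplied explicitly (hypothetical caller).
  CollectiveManager::instance()->Initialize("nccl", "nccl_world_group");
  // After: both are derived internally.
  CollectiveManager::instance()->Initialize();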
@@ -91,7 +93,7 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
                                                 const std::vector<uint32_t> &group_ranks) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);

  // Step 1: Create communication group on host side.
  // Step 1: Create communication group on host side if.
  if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
    MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side.";
    return false;
  }
@@ -167,6 +169,12 @@ bool CollectiveManager::Finalize() {
  return true;
}

void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }

void CollectiveManager::set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }

uint32_t CollectiveManager::local_rank_id() const { return local_rank_id_; }

bool CollectiveManager::InitHostCommlib() {
  device::DeviceContextKey host_key = {"CPU", 0};
  host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
@@ -175,10 +183,7 @@ bool CollectiveManager::InitHostCommlib() {
    MS_LOG(ERROR) << "Failed to load communication library on the host side.";
    return false;
  }

  host_comm_lib_ = host_ctx_->collective_comm_lib();
  MS_EXCEPTION_IF_NULL(host_comm_lib_);
  auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_);
  host_comm_lib_instance_ = instance_func();
  host_comm_lib_instance_ = host_ctx_->collective_comm_lib();
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);

  // For some communication libraries, 'global_rank_id_' and 'global_rank_size_' should be set by caller, e.g., when using
@@ -197,33 +202,30 @@ bool CollectiveManager::InitHostCommlib() {
    global_group_ranks_.push_back(i);
  }

  // Create world group on host side for AllGather operation of host name while assigning local rank.
  host_global_group_name_ = host_comm_lib_instance_->global_group_name();
  if (!host_comm_lib_instance_->CreateCommunicationGroup(host_global_group_name_, global_group_ranks_)) {
    MS_LOG(ERROR) << "Failed to create communication group " << host_global_group_name_ << " on host side.";
    return false;
  }

  MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
               << ", global rank size: " << global_rank_size_;
  return true;
}

bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
  std::string device_name;
  if (backend == "nccl") {
    device_name = "GPU";
  } else if (backend == "hccl") {
    device_name = "Ascend";
  } else {
    MS_LOG(ERROR) << "Backend " << backend << " is not supported.";
    return false;
  }
  device::DeviceContextKey device_key = {device_name, local_rank_id_};
bool CollectiveManager::InitDeviceCommLib() {
  device::DeviceContextKey device_key = {device_type_, local_rank_id_};
  device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
  MS_EXCEPTION_IF_NULL(device_ctx_);
  // We can initialize device context now because device id (local_rank_id_) is already assigned.
  device_ctx_->Initialize();

  if (!device_ctx_->LoadCollectiveCommLib()) {
    MS_LOG(ERROR) << "Failed to load communication library on the device side.";
    return false;
  }

  device_comm_lib_ = device_ctx_->collective_comm_lib();
  MS_EXCEPTION_IF_NULL(device_comm_lib_);
  auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_);
  device_comm_lib_instance_ = instance_func();
  device_comm_lib_instance_ = device_ctx_->collective_comm_lib();
  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);

  MS_LOG(INFO) << "Start initializing communication library on device side...";
@@ -235,7 +237,7 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
  return true;
}
bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
bool CollectiveManager::AssignLocalRank() {
  char host_name[MAX_HOSTNAME_LEN] = {0};
#ifndef _WIN32
  if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
@@ -259,8 +261,8 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  // AllGather host names across the global communication group.
  if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, sizeof(size_t), TypeId::kNumberTypeInt,
                                          global_group_name, nullptr)) {
  if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeInt,
                                          host_global_group_name_, nullptr)) {
    MS_LOG(ERROR) << "AllGather for host names failed.";
    return false;
  }
@@ -274,7 +276,10 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
      local_rank_id_++;
    }
  }

  MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_;
  MsContext::GetInstance()->set_param<uint32_t>(MS_CTX_DEVICE_ID, local_rank_id_);
  MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_
               << ". device_id of ms_context is set.";
  return true;
}
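The local rank assignment above works by hashing each process's host name, gathering the hashes in global-rank order, and counting how many lower-ranked processes share this host's hash. A standalone sketch of that counting step, assuming the gathered hashes are already available (MindSpore obtains them via the host library's AllGather):

#include <cstdint>
#include <vector>

// all_host_hashs[i] holds the host-name hash of global rank i.
uint32_t ComputeLocalRank(const std::vector<size_t> &all_host_hashs, uint32_t global_rank) {
  uint32_t local_rank = 0;
  for (uint32_t rank = 0; rank < global_rank; ++rank) {
    // Each lower global rank on the same host occupies one local device slot.
    if (all_host_hashs[rank] == all_host_hashs[global_rank]) {
      ++local_rank;
    }
  }
  return local_rank;
}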
}  // namespace collective
@@ -43,8 +43,8 @@ class CollectiveManager {
  DISABLE_COPY_AND_ASSIGN(CollectiveManager);
  static std::shared_ptr<CollectiveManager> instance();

  // Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL.
  bool Initialize(const std::string &backend, const std::string &global_group_name);
  // Initialize the collective communication for distributed training. The backend type is read from MindSpore context.
  bool Initialize();

  // Finalize the collective communication.
  bool Finalize();
@@ -63,8 +63,10 @@ class CollectiveManager {
  // In some cases global rank id and rank size should be set by caller, e.g., when using MindSpore communication
  // framework, they're generated by cluster::ClusterContext.
  uint32_t set_global_rank_id();
  uint32_t set_global_rank_size();
  void set_global_rank_id(uint32_t global_rank_id);
  void set_global_rank_size(uint32_t global_rank_size);

  uint32_t local_rank_id() const;

 private:
  CollectiveManager();
@@ -73,14 +75,17 @@ class CollectiveManager {
  bool InitHostCommlib();

  // Initialize communication library on device side.
  bool InitDeviceCommLib(const std::string &backend);
  bool InitDeviceCommLib();

  // Assign the local rank id for this process.
  bool AssignLocalRank(const std::string &global_group_name);
  bool AssignLocalRank();

  std::atomic_bool inited_;
  std::atomic_bool finalized_;

  // The device type read from MindSpore context.
  std::string device_type_;

  // The device context on both host and device side. They are used to access the communication library on different
  // devices.
  DeviceContext *host_ctx_;
@@ -88,12 +93,10 @@ class CollectiveManager {
  // Host communication library refers to the communication library for CPU, e.g., OpenMPI and MindSpore communication
  // framework.
  void *host_comm_lib_;
  CollectiveCommunicationLib *host_comm_lib_instance_;

  // Device communication library refers to the communication library for NPU or GPU, e.g., NCCL and HCCL.
  // When only CPU backend is used, device communication library should not be initialized.
  void *device_comm_lib_;
  CollectiveCommunicationLib *device_comm_lib_instance_;

  // The global rank id of this process. Normally this range is 0 to `total process number - 1`.
@@ -107,6 +110,10 @@ class CollectiveManager {
  // Global group ranks.
  std::vector<uint32_t> global_group_ranks_;

  // The global group name on the host side. This is used for creating the global group on host side for the AllGather
  // operation of host names while assigning local ranks.
  std::string host_global_group_name_;
};
}  // namespace collective
}  // namespace distributed
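Taken together with the setters above, the expected call order when the cluster is built by the MindSpore communication framework is: inject the rank metadata first, then call the argument-free Initialize(). A sketch mirroring the sequence used in distributed::Initialize() later in this change:

  auto manager = distributed::collective::CollectiveManager::instance();
  manager->set_global_rank_id(rank_id);        // e.g. from AbstractNode::rank_id()
  manager->set_global_rank_size(worker_num);   // e.g. from AbstractNode::worker_num()
  if (!manager->Initialize()) {
    // Handle the failure; Initialize() logs the failing step.
  }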
@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_

#include <set>
#include <map>
#include <string>

namespace mindspore {
@@ -20,16 +20,29 @@
namespace mindspore {
namespace distributed {
bool Initialize(const std::string &backend, const std::string &global_group_name) {
bool Initialize() {
  if (!InitializeCluster()) {
    MS_LOG(ERROR) << "Failed to initialize cluster.";
    return false;
  }

  if (!InitializeCollective(backend, global_group_name)) {
    MS_LOG(ERROR) << "Failed to initialize collective communication.";
    return false;
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  // Server and Scheduler don't use collective communication library.
  auto node = cluster::ClusterContext::instance()->node();
  MS_EXCEPTION_IF_NULL(node);
  if (node->role() != ps::core::NodeRole::SERVER && node->role() != ps::core::NodeRole::SCHEDULER) {
    // Global rank id and size should be manually set if cluster is initialized by MindSpore communication framework.
    auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster::ClusterContext::instance()->node());
    MS_EXCEPTION_IF_NULL(abstract_node);
    collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
    collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());

    if (!InitializeCollective()) {
      MS_LOG(ERROR) << "Failed to initialize collective communication.";
      return false;
    }
  }
#endif
  return true;
}
@@ -51,9 +64,7 @@ bool InitializeCluster() { return cluster::ClusterContext::instance()->Initializ
bool FinalizeCluster() { return cluster::ClusterContext::instance()->Finalize(); }

bool InitializeCollective(const std::string &backend, const std::string &global_group_name) {
  return collective::CollectiveManager::instance()->Initialize(backend, global_group_name);
}
bool InitializeCollective() { return collective::CollectiveManager::instance()->Initialize(); }

bool FinalizeCollective() { return collective::CollectiveManager::instance()->Finalize(); }
}  // namespace distributed
@@ -31,7 +31,7 @@ namespace distributed {
// The static methods of MindSpore distributed execution. They can be exported by Pybind.
// Initialize and finalize distributed execution.
bool Initialize(const std::string &backend, const std::string &global_group_name);
bool Initialize();
bool Finalize();

// Initialize and finalize the cluster based on MindSpore communication framework.
@@ -39,7 +39,7 @@ bool InitializeCluster();
bool FinalizeCluster();

// Initialize and finalize collective communication for distributed execution.
bool InitializeCollective(const std::string &backend, const std::string &global_group_name);
bool InitializeCollective();
bool FinalizeCollective();
}  // namespace distributed
}  // namespace mindspore
@@ -410,7 +410,8 @@ void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
core::ClusterConfig &PSContext::cluster_config() {
  if (cluster_config_ == nullptr) {
    MS_LOG(EXCEPTION) << "The cluster config is empty.";
    cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
    MS_EXCEPTION_IF_NULL(cluster_config_);
  }
  return *cluster_config_;
}
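This hunk swaps the hard failure for lazy construction: a missing config is now built on first access from the PS context's stored parameters. The same idiom in a minimal standalone form (hypothetical types, not the MindSpore ones):

#include <memory>

struct Config {
  int worker_num = 0;
};

Config &GetConfig() {
  static std::unique_ptr<Config> config;
  if (config == nullptr) {
    config = std::make_unique<Config>();  // Construct on first access instead of throwing.
  }
  return *config;
}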
@@ -16,6 +16,8 @@
#include "runtime/device/gpu/distribution/collective_init.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "distributed/init.h"

namespace mindspore {
namespace device {
@@ -30,23 +32,32 @@ bool CollectiveInitializer::collective_inited() const { return collective_inited
const void *CollectiveInitializer::collective_handle() { return collective_handle_; }

void CollectiveInitializer::InitCollective() {
  void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
  if (handle == nullptr) {
    MS_LOG(EXCEPTION)
      << "Loading libgpu_collective.so failed. Many reasons could cause this:\n"
         "1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built "
         "with distributed feature.\n"
         "2.NCCL is not found or the user-installed NCCL version is incompatible: MindSpore "
         "requires NCCL-2.7.6.\n"
         "3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore "
         "requires OpenMPI-4.0.3.\n";
  }
  auto mpi_init_funcptr = reinterpret_cast<InitMPI>(dlsym(handle, "InitMPI"));
  MS_EXCEPTION_IF_NULL(mpi_init_funcptr);
  (*mpi_init_funcptr)();
  if (common::CheckUseMPI()) {
    void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
    if (handle == nullptr) {
      MS_LOG(EXCEPTION)
        << "Loading libgpu_collective.so failed. Many reasons could cause this:\n"
           "1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built "
           "with distributed feature.\n"
           "2.NCCL is not found or the user-installed NCCL version is incompatible: MindSpore "
           "requires NCCL-2.7.6.\n"
           "3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore "
           "requires OpenMPI-4.0.3.\n";
    }
    auto mpi_init_funcptr = reinterpret_cast<InitMPI>(dlsym(handle, "InitMPI"));
    MS_EXCEPTION_IF_NULL(mpi_init_funcptr);
    (*mpi_init_funcptr)();
  CollectiveInitializer::instance().collective_inited_ = true;
  CollectiveInitializer::instance().collective_handle_ = handle;
    // Because this method InitCollective is static, the non-static member variables should be accessed by
    // CollectiveInitializer::instance().
    CollectiveInitializer::instance().use_mpi_ = true;
    CollectiveInitializer::instance().collective_inited_ = true;
    CollectiveInitializer::instance().collective_handle_ = handle;
  } else {
    if (!distributed::Initialize()) {
      MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL.";
    }
  }
}
void CollectiveInitializer::FinalizeCollective() {
@@ -56,6 +67,69 @@ void CollectiveInitializer::FinalizeCollective() {
    }
  }
}

uint32_t CollectiveInitializer::local_rank_id() {
  uint32_t local_rank_id;
  if (common::CheckUseMPI()) {
    MS_EXCEPTION_IF_NULL(collective_handle_);
    auto get_local_rank_funcptr =
      reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
    MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
    local_rank_id = IntToUint((*get_local_rank_funcptr)());
  } else {
    local_rank_id = distributed::collective::CollectiveManager::instance()->local_rank_id();
  }
  return local_rank_id;
}
bool CollectiveInitializer::CreateCommunicationGroup(const std::string &group_name,
                                                     const std::vector<uint32_t> &group_ranks) {
  if (common::CheckUseMPI()) {
    MS_EXCEPTION_IF_NULL(collective_handle_);
    auto create_comm_group_funcptr =
      reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
    MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
    return (*create_comm_group_funcptr)(group_name, group_ranks);
  } else {
    return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks);
  }
}
bool CollectiveInitializer::DestroyCommunicationGroup(const std::string &group_name) {
  if (common::CheckUseMPI()) {
    MS_EXCEPTION_IF_NULL(collective_handle_);
    auto destroy_group_funcptr =
      reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
    MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
    return (*destroy_group_funcptr)(group_name);
  } else {
    return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name);
  }
}
uint32_t CollectiveInitializer::GetRankIDByGroup(const std::string &group_name) {
  if (common::CheckUseMPI()) {
    MS_EXCEPTION_IF_NULL(collective_handle_);
    auto get_rank_id_funcptr =
      reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
    MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
    return IntToUint((*get_rank_id_funcptr)(group_name));
  } else {
    return distributed::collective::CollectiveManager::instance()->GetRankId(group_name);
  }
}
uint32_t CollectiveInitializer::GetGroupSize(const std::string &group_name) {
  if (common::CheckUseMPI()) {
    MS_EXCEPTION_IF_NULL(collective_handle_);
    auto get_group_size_funcptr =
      reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
    MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
    return IntToUint((*get_group_size_funcptr)(group_name));
  } else {
    return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name);
  }
}
}  // namespace gpu
}  // namespace device
}  // namespace mindspore
@@ -42,12 +42,20 @@ class CollectiveInitializer {
  static void InitCollective();
  static void FinalizeCollective();

  // The encapsulation of the collective communication APIs for compatibility.
  uint32_t local_rank_id();
  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);
  bool DestroyCommunicationGroup(const std::string &group_name);
  uint32_t GetRankIDByGroup(const std::string &group_name);
  uint32_t GetGroupSize(const std::string &group_name);

 private:
  CollectiveInitializer() : collective_inited_(false) {}
  CollectiveInitializer() : use_mpi_(false), collective_inited_(false), collective_handle_(nullptr) {}
  ~CollectiveInitializer() = default;

  bool use_mpi_;
  bool collective_inited_;
  void *collective_handle_{nullptr};
  void *collective_handle_;
};
}  // namespace gpu
}  // namespace device
@@ -23,6 +23,9 @@ endif()

if(ENABLE_CPU)
    file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
    list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc")
    if(WIN32)
        list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/ms_collective_comm_lib.cc")
    endif()
    if(ENABLE_MPI)
        set(MPI_COLLECTIVE_SRCS "cpu/mpi_collective_comm_lib.cc"
                                "cpu/mpi_communication_group.cc"
@@ -60,6 +60,8 @@ CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &gr
  return groups_[group_name];
}

const std::string &CollectiveCommunicationLib::global_group_name() const { return global_group_name_; }

uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }

uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }
@@ -77,6 +77,10 @@ class CollectiveCommunicationLib {
    return true;
  }

  // Returns the global group name of this collective communication library. For NCCL, it's 'nccl_world_group'. For
  // HCCL, it's 'hccl_world_group'.
  const std::string &global_group_name() const;

  // Returns global rank id of this process.
  uint32_t global_rank_id() const;
@@ -90,6 +94,9 @@ class CollectiveCommunicationLib {
  // Whether this collective communication library is initialized.
  bool initialized_;

  // The global group name.
  std::string global_group_name_;

  // The global rank id of this process. Normally this range is 0 to `total process number - 1`.
  uint32_t global_rank_id_;
@@ -47,5 +47,11 @@ uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
}

uint32_t CommunicationGroup::group_size() const { return size_; }

const std::vector<uint32_t> &CommunicationGroup::group_ranks() const { return group_ranks_; }

const std::map<uint32_t, uint32_t> &CommunicationGroup::global_to_group_ranks() const { return global_to_group_ranks_; }

const std::map<uint32_t, uint32_t> &CommunicationGroup::group_to_global_ranks() const { return group_to_global_ranks_; }
}  // namespace device
}  // namespace mindspore
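The two map accessors added above expose inverse views over the group's rank list. A hedged standalone illustration (hypothetical helper, not in the source) of how such maps relate to group_ranks_:

#include <cstdint>
#include <map>
#include <vector>

// For group_ranks = {3, 5, 7}: global_to_group = {3:0, 5:1, 7:2}, group_to_global = {0:3, 1:5, 2:7}.
void BuildRankMaps(const std::vector<uint32_t> &group_ranks, std::map<uint32_t, uint32_t> *global_to_group,
                   std::map<uint32_t, uint32_t> *group_to_global) {
  for (uint32_t group_rank = 0; group_rank < group_ranks.size(); ++group_rank) {
    (*global_to_group)[group_ranks[group_rank]] = group_rank;
    (*group_to_global)[group_rank] = group_ranks[group_rank];
  }
}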
@@ -55,6 +55,11 @@ class CommunicationGroup {
  // Return the size of this communication group.
  uint32_t group_size() const;

  // Return group ranks info.
  const std::vector<uint32_t> &group_ranks() const;
  const std::map<uint32_t, uint32_t> &global_to_group_ranks() const;
  const std::map<uint32_t, uint32_t> &group_to_global_ranks() const;

 protected:
  // Whether this communication group is initialized.
  bool initialized_;
@@ -34,6 +34,9 @@
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "profiler/device/cpu/cpu_profiling.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "runtime/hardware/cpu/ms_collective_comm_lib.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
@@ -296,6 +299,32 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
  return DoLaunchKernel(kernel_mod, inputs, workspace, outputs);
}

bool CPUDeviceContext::LoadCollectiveCommLib() {
  bool using_mpi = common::CheckUseMPI();
  if (using_mpi) {
    std::string mpi_comm_lib_name = "libmpi_collective.so";
    auto loader = std::make_shared<CollectiveCommLibLoader>(mpi_comm_lib_name);
    MS_EXCEPTION_IF_NULL(loader);
    if (!loader->Initialize()) {
      MS_LOG(EXCEPTION) << "Failed to load mpi collective library.";
      return false;
    }

    void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
    MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);

    auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
    collective_comm_lib_ = instance_func();
    MS_EXCEPTION_IF_NULL(collective_comm_lib_);
  } else {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
    collective_comm_lib_ = &MsCollectiveCommLib::GetInstance();
    MS_EXCEPTION_IF_NULL(collective_comm_lib_);
#endif
  }
  return true;
}
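The DlsymFuncObj call above resolves the library's exported factory and invokes it to obtain the typed singleton; it is presumably a thin wrapper over plain dlsym. A POSIX-only sketch of the underlying pattern, under that assumption (ResolveCommLibInstance is a hypothetical name):

#include <dlfcn.h>

class CollectiveCommunicationLib;  // Defined in collective_communication_lib.h.
using InstanceFunc = CollectiveCommunicationLib *(*)();

CollectiveCommunicationLib *ResolveCommLibInstance(void *handle) {
  // 'communication_lib_instance' is the single extern "C" factory each comm library exports.
  auto func = reinterpret_cast<InstanceFunc>(dlsym(handle, "communication_lib_instance"));
  return (func != nullptr) ? func() : nullptr;
}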
bool CPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
                                                 const std::vector<AddressPtr> &workspace,
                                                 const std::vector<AddressPtr> &outputs) const {
@@ -57,6 +57,8 @@ class CPUDeviceContext : public DeviceContext {
                    const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
                    bool is_dynamic_shape = false) const override;

  bool LoadCollectiveCommLib() override;

 private:
  DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);
@@ -19,6 +19,8 @@
namespace mindspore {
namespace device {
namespace cpu {
MPICollectiveCommLib::MPICollectiveCommLib() { global_group_name_ = kMPIGlobalGroupName; }

bool MPICollectiveCommLib::Initialize(uint32_t, uint32_t) {
  if (initialized_) {
    return false;
@@ -56,49 +58,7 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam
  }
}  // namespace cpu

// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;

CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); }

bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
  return MPICollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}

bool DestroyCommunicationGroup(const std::string &group_name) {
  return MPICollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}

uint32_t GetRankId(const std::string &group_name) { return MPICollectiveCommLib::GetInstance().GetRankId(group_name); }

uint32_t GetCommunicationGroupSize(const std::string &group_name) {
  return MPICollectiveCommLib::GetInstance().GetGroupSize(group_name);
}

bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
  return MPICollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
               const std::string &group_name, void *stream) {
  return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream);
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
               uint32_t root_rank, const std::string &group_name, void *stream) {
  return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
                                                       group_name, stream);
}

uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); }

uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }

uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); }
}  // namespace device
}  // namespace mindspore
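With the per-API wrappers above removed, the only symbol a loader still needs is the 'communication_lib_instance' factory, which the DlsymFuncObj calls in the device contexts resolve by name; every other operation goes through the returned object's virtual interface. A hedged sketch of that surviving export, assuming the extern "C" declaration is kept in some form:

extern "C" __attribute__((visibility("default"))) CollectiveCommunicationLib *communication_lib_instance();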
@@ -23,10 +23,15 @@
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/cpu/mpi_communication_group.h"

#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif

namespace mindspore {
namespace device {
namespace cpu {
class MPICollectiveCommLib : public CollectiveCommunicationLib {
constexpr char kMPIGlobalGroupName[] = "mpi_world_group";
class EXPORT_MPI_WRAPPER MPICollectiveCommLib : public CollectiveCommunicationLib {
 public:
  static MPICollectiveCommLib &GetInstance() {
    static MPICollectiveCommLib instance;
@@ -49,35 +54,14 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
  }

 private:
  MPICollectiveCommLib() = default;
  MPICollectiveCommLib();
  ~MPICollectiveCommLib() override = default;

  MPI_Group world_group_;
};
}  // namespace cpu

#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif

extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
                                                           uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_MPI_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
                                                            const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
                                             mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
                                             mindspore::TypeId data_type, uint32_t root_rank,
                                             const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size();
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_
@@ -19,6 +19,8 @@
namespace mindspore {
namespace device {
namespace cpu {
MsCollectiveCommLib::MsCollectiveCommLib() { global_group_name_ = kMSGlobalGroupName; }

bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
  if (initialized_) {
    return false;
@@ -30,8 +32,6 @@ bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_
  return true;
}

bool MsCollectiveCommLib::Finalize() { return true; }

bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
                                                   const std::vector<uint32_t> &group_ranks) {
  CHECK_RET((groups_.count(group_name) == 0), true, "The group " + group_name + " has already existed.");
@@ -22,10 +22,17 @@
#include <string>
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/cpu/ms_communication_group.h"
#include "distributed/cluster/cluster_context.h"
#include "fl/server/collective_ops_impl.h"

namespace mindspore {
namespace device {
namespace cpu {
constexpr char kMSGlobalGroupName[] = "ms_world_group";
using ClusterContext = mindspore::distributed::cluster::ClusterContext;
using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl;
using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;

// The collective communication library for the MindSpore self-developed communication framework.
class MsCollectiveCommLib : public CollectiveCommunicationLib {
 public:
@@ -35,12 +42,11 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {
  }

  bool Initialize(uint32_t global_rank = UINT32_MAX, uint32_t global_rank_size = UINT32_MAX) override;
  bool Finalize() override;

  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

 private:
  MsCollectiveCommLib() {}
  MsCollectiveCommLib();
  ~MsCollectiveCommLib() override = default;
};
}  // namespace cpu
@@ -50,7 +50,7 @@ struct DeviceContextKey {
class DeviceContext {
 public:
  explicit DeviceContext(const DeviceContextKey &device_context_key)
      : device_context_key_(device_context_key), collective_comm_lib_ptr_(nullptr) {}
      : device_context_key_(device_context_key), collective_comm_lib_(nullptr) {}
  virtual ~DeviceContext() = default;

  // Initialize the device context.
@@ -150,7 +150,7 @@ class DeviceContext {
  virtual bool LoadCollectiveCommLib() { return true; }

  // Return the collective communication object for callers to access.
  void *collective_comm_lib() const { return collective_comm_lib_ptr_; }
  CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }

  // TODO(jiaorui): will be deleted
  // Dump all graphs.
@@ -159,8 +159,8 @@ class DeviceContext {
 protected:
  DeviceContextKey device_context_key_;

  // The dynamically loaded handle for collective communication library by 'dlopen'.
  void *collective_comm_lib_ptr_;
  // The collective communication library.
  CollectiveCommunicationLib *collective_comm_lib_;
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;
}  // namespace device
@@ -534,8 +534,12 @@ bool GPUDeviceContext::LoadCollectiveCommLib() {
    MS_LOG(EXCEPTION) << "Loading NCCL collective library failed.";
    return false;
  }

  collective_comm_lib_ptr_ = loader->collective_comm_lib_ptr();
  MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
  void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
  MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);

  auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
  collective_comm_lib_ = instance_func();
  MS_EXCEPTION_IF_NULL(collective_comm_lib_);

  return true;
#else
  return false;
@@ -19,6 +19,8 @@
namespace mindspore {
namespace device {
namespace gpu {
NvidiaCollectiveCommLib::NvidiaCollectiveCommLib() { global_group_name_ = kNCCLGlobalGroupName; }

bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
  if (initialized_) {
    return false;
@@ -42,50 +44,7 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_
  }
}  // namespace gpu

// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;

CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
  return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}

bool FinalizeCollectiveLib() { return NvidiaCollectiveCommLib::GetInstance().Finalize(); }

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
  return NvidiaCollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}

bool DestroyCommunicationGroup(const std::string &group_name) {
  return NvidiaCollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}

uint32_t GetRankId(const std::string &group_name) {
  return NvidiaCollectiveCommLib::GetInstance().GetRankId(group_name);
}

uint32_t GetCommunicationGroupSize(const std::string &group_name) {
  return NvidiaCollectiveCommLib::GetInstance().GetGroupSize(group_name);
}

bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
  return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
               const std::string &group_name, void *stream) {
  return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name,
                                                          stream);
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
               uint32_t root_rank, const std::string &group_name, void *stream) {
  return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
                                                          group_name, stream);
}

uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }
}  // namespace device
}  // namespace mindspore
@@ -24,11 +24,15 @@
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/gpu/nvidia_communication_group.h"

#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif

namespace mindspore {
namespace device {
namespace gpu {
constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
constexpr char kNCCLGlobalGroupName[] = "nccl_world_group";
class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
 public:
  static NvidiaCollectiveCommLib &GetInstance() {
    static NvidiaCollectiveCommLib instance;
@@ -50,31 +54,12 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
  }

 private:
  NvidiaCollectiveCommLib() = default;
  NvidiaCollectiveCommLib();
  ~NvidiaCollectiveCommLib() override = default;
};
}  // namespace gpu

#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif

extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
                                                            uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_NCCL_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
                                                             const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
                                              mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
                                              mindspore::TypeId data_type, uint32_t root_rank,
                                              const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_
@@ -90,6 +90,16 @@ static inline bool IsLittleByteOrder() {
  return false;
}

static inline bool CheckUseMPI() {
  // If these OpenMPI environment variables are set, we consider this process to be launched by OpenMPI.
  std::string ompi_command_env = GetEnv("OMPI_COMMAND");
  std::string pmix_rank_env = GetEnv("PMIX_RANK");
  if (!ompi_command_env.empty() && !pmix_rank_env.empty()) {
    return true;
  }
  return false;
}
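CheckUseMPI treats the presence of both variables as evidence of an mpirun launch: OpenMPI's launcher exports OMPI_COMMAND and the PMIx runtime exports PMIX_RANK for every ranked process. An equivalent standalone check with std::getenv, for reference (LaunchedByOpenMPI is a hypothetical name):

#include <cstdlib>

bool LaunchedByOpenMPI() {
  const char *ompi_command = std::getenv("OMPI_COMMAND");
  const char *pmix_rank = std::getenv("PMIX_RANK");
  // Both must be present and non-empty, matching CheckUseMPI above.
  return ompi_command != nullptr && *ompi_command != '\0' && pmix_rank != nullptr && *pmix_rank != '\0';
}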
template <typename T>
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
  if (a == b) {