Browse Source

1.Purge not used API.

2.Adapt for collective_init.h
tags/v1.6.0
ZPaC 4 years ago
parent
commit
2b7429c5d2
28 changed files with 281 additions and 214 deletions
  1. +1
    -1
      mindspore/ccsrc/distributed/cluster/cluster_context.cc
  2. +2
    -1
      mindspore/ccsrc/distributed/cluster/cluster_context.h
  3. +0
    -2
      mindspore/ccsrc/distributed/cluster/dummy_cluster_context.cc
  4. +0
    -1
      mindspore/ccsrc/distributed/cluster/dummy_cluster_context.h
  5. +38
    -33
      mindspore/ccsrc/distributed/collective/collective_manager.cc
  6. +15
    -8
      mindspore/ccsrc/distributed/collective/collective_manager.h
  7. +1
    -0
      mindspore/ccsrc/distributed/constants.h
  8. +18
    -7
      mindspore/ccsrc/distributed/init.cc
  9. +2
    -2
      mindspore/ccsrc/distributed/init.h
  10. +2
    -1
      mindspore/ccsrc/ps/ps_context.cc
  11. +90
    -16
      mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc
  12. +10
    -2
      mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h
  13. +3
    -0
      mindspore/ccsrc/runtime/hardware/CMakeLists.txt
  14. +2
    -0
      mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc
  15. +7
    -0
      mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h
  16. +6
    -0
      mindspore/ccsrc/runtime/hardware/collective/communication_group.cc
  17. +5
    -0
      mindspore/ccsrc/runtime/hardware/collective/communication_group.h
  18. +29
    -0
      mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
  19. +2
    -0
      mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h
  20. +2
    -42
      mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc
  21. +7
    -23
      mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h
  22. +2
    -2
      mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.cc
  23. +8
    -2
      mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.h
  24. +4
    -4
      mindspore/ccsrc/runtime/hardware/device_context.h
  25. +6
    -2
      mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc
  26. +2
    -43
      mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc
  27. +7
    -22
      mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h
  28. +10
    -0
      mindspore/core/utils/ms_utils.h

+ 1
- 1
mindspore/ccsrc/distributed/cluster/cluster_context.cc View File

@@ -85,7 +85,7 @@ bool ClusterContext::Finalize() {
return true;
}

std::string ClusterContext::node_role() const { return node_role_; }
const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }

void ClusterContext::InitClusterConfig() {
InitNodeRole();


+ 2
- 1
mindspore/ccsrc/distributed/cluster/cluster_context.h View File

@@ -49,7 +49,8 @@ class ClusterContext {
// Finalize the cluster and process exits.
bool Finalize();

std::string node_role() const;
// Return node object of this process.
const std::shared_ptr<ps::core::Node> &node() const;

private:
ClusterContext();


+ 0
- 2
mindspore/ccsrc/distributed/cluster/dummy_cluster_context.cc View File

@@ -32,8 +32,6 @@ std::shared_ptr<ClusterContext> ClusterContext::instance() {
bool ClusterContext::Initialize() const { return true; }

bool ClusterContext::Finalize() const { return true; }

std::string ClusterContext::node_role() const { return ""; }
} // namespace cluster
} // namespace distributed
} // namespace mindspore

+ 0
- 1
mindspore/ccsrc/distributed/cluster/dummy_cluster_context.h View File

@@ -39,7 +39,6 @@ class ClusterContext {

bool Initialize() const;
bool Finalize() const;
std::string node_role() const;

private:
ClusterContext() = default;


+ 38
- 33
mindspore/ccsrc/distributed/collective/collective_manager.cc View File

@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include <memory>
#include "utils/ms_context.h"

namespace mindspore {
namespace distributed {
@@ -27,9 +28,7 @@ CollectiveManager::CollectiveManager()
finalized_(true),
host_ctx_(nullptr),
device_ctx_(nullptr),
host_comm_lib_(nullptr),
host_comm_lib_instance_(nullptr),
device_comm_lib_(nullptr),
device_comm_lib_instance_(nullptr),
global_rank_id_(0),
local_rank_id_(0),
@@ -51,11 +50,13 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
return instance;
}

bool CollectiveManager::Initialize(const std::string &backend, const std::string &global_group_name) {
bool CollectiveManager::Initialize() {
if (inited_) {
return true;
}
MS_LOG(INFO) << "Start initializing collective communication for backend: " << backend << "...";

device_type_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
MS_LOG(INFO) << "Start initializing collective communication for backend: " << device_type_ << "...";

// Step 1: Initialize host side collective communication.
if (!InitHostCommlib()) {
@@ -66,24 +67,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
// Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will not
// be necessary.
// Step 2: Assign local rank id(device id) for this process.
if (!AssignLocalRank(global_group_name)) {
if (!AssignLocalRank()) {
MS_LOG(ERROR) << "Failed to assign local rank id.";
return false;
}

// Step 3: Initialize device side collective communication.
if (!InitDeviceCommLib(backend)) {
if (!InitDeviceCommLib()) {
MS_LOG(ERROR) << "Failed to initialize device communication library.";
return false;
}

// Step 4: Create global communication group.
if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) {
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
if (!CreateCommunicationGroup(device_comm_lib_instance_->global_group_name(), global_group_ranks_)) {
MS_LOG(ERROR) << "Failed to initialize host communication library.";
return false;
}

MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << ".";
MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_;
return true;
}

@@ -91,7 +93,7 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
// Step 1: Create communication group on host side.
// Step 1: Create communication group on host side if.
if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side.";
return false;
@@ -167,6 +169,12 @@ bool CollectiveManager::Finalize() {
return true;
}

void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }

void CollectiveManager::set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }

uint32_t CollectiveManager::local_rank_id() const { return local_rank_id_; }

bool CollectiveManager::InitHostCommlib() {
device::DeviceContextKey host_key = {"CPU", 0};
host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
@@ -175,10 +183,7 @@ bool CollectiveManager::InitHostCommlib() {
MS_LOG(ERROR) << "Failed to load communication library on the host side.";
return false;
}
host_comm_lib_ = host_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_);
auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_);
host_comm_lib_instance_ = instance_func();
host_comm_lib_instance_ = host_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);

// For some communication libraries, global_rank_id_', 'global_rank_size_' should be set by caller, e.g., when using
@@ -197,33 +202,30 @@ bool CollectiveManager::InitHostCommlib() {
global_group_ranks_.push_back(i);
}

// Create world group on host side for AllGather operation of host name while assigning local rank.
host_global_group_name_ = host_comm_lib_instance_->global_group_name();
if (!host_comm_lib_instance_->CreateCommunicationGroup(host_global_group_name_, global_group_ranks_)) {
MS_LOG(ERROR) << "Failed to create communication group " << host_global_group_name_ << " on host side.";
return false;
}

MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
<< ", global rank size: " << global_rank_size_;
return true;
}

bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
std::string device_name;
if (backend == "nccl") {
device_name = "GPU";
} else if (backend == "hccl") {
device_name = "Ascend";
} else {
MS_LOG(ERROR) << "Backend " << backend << " is not supported.";
return false;
}

device::DeviceContextKey device_key = {device_name, local_rank_id_};
bool CollectiveManager::InitDeviceCommLib() {
device::DeviceContextKey device_key = {device_type_, local_rank_id_};
device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
MS_EXCEPTION_IF_NULL(device_ctx_);
// We can initialize device context now because device id(local_rank_id_) is already assigned.
device_ctx_->Initialize();

if (!device_ctx_->LoadCollectiveCommLib()) {
MS_LOG(ERROR) << "Failed to load communication library on the device side.";
return false;
}
device_comm_lib_ = device_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(device_comm_lib_);
auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_);
device_comm_lib_instance_ = instance_func();
device_comm_lib_instance_ = device_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);

MS_LOG(INFO) << "Start initializing communication library on device side...";
@@ -235,7 +237,7 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
return true;
}

bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
bool CollectiveManager::AssignLocalRank() {
char host_name[MAX_HOSTNAME_LEN] = {0};
#ifndef _WIN32
if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
@@ -259,8 +261,8 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {

MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
// AllGather host names across the global communication group.
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, sizeof(size_t), TypeId::kNumberTypeInt,
global_group_name, nullptr)) {
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeInt,
host_global_group_name_, nullptr)) {
MS_LOG(ERROR) << "AllGather for host names failed.";
return false;
}
@@ -274,7 +276,10 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
local_rank_id_++;
}
}
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_;

MsContext::GetInstance()->set_param<uint32_t>(MS_CTX_DEVICE_ID, local_rank_id_);
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_
<< ". device_id of ms_context is set.";
return true;
}
} // namespace collective


+ 15
- 8
mindspore/ccsrc/distributed/collective/collective_manager.h View File

@@ -43,8 +43,8 @@ class CollectiveManager {
DISABLE_COPY_AND_ASSIGN(CollectiveManager);
static std::shared_ptr<CollectiveManager> instance();

// Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL.
bool Initialize(const std::string &backend, const std::string &global_group_name);
// Initialize the collective communication for distributed training. The backend type is read from MindSpore context.
bool Initialize();

// Finalize the collective communication.
bool Finalize();
@@ -63,8 +63,10 @@ class CollectiveManager {

// In some cases global rank id and rank size should be set by caller, e.g., when using MindSpore communication
// framework, they're generated by cluster::ClusterContext.
uint32_t set_global_rank_id();
uint32_t set_global_rank_size();
void set_global_rank_id(uint32_t global_rank_id);
void set_global_rank_size(uint32_t global_rank_size);

uint32_t local_rank_id() const;

private:
CollectiveManager();
@@ -73,14 +75,17 @@ class CollectiveManager {
bool InitHostCommlib();

// Initialize communication library on device side.
bool InitDeviceCommLib(const std::string &backend);
bool InitDeviceCommLib();

// Assign the local rank id for this process.
bool AssignLocalRank(const std::string &global_group_name);
bool AssignLocalRank();

std::atomic_bool inited_;
std::atomic_bool finalized_;

// The device type read from MindSpore context.
std::string device_type_;

// The device context on both host and device side. They are used to access the communication library on different
// devices.
DeviceContext *host_ctx_;
@@ -88,12 +93,10 @@ class CollectiveManager {

// Host communication library refers to the communication libaray for CPU, e.g., OpenMPI and MindSpore communication
// framework.
void *host_comm_lib_;
CollectiveCommunicationLib *host_comm_lib_instance_;

// Device communication library refers to the communication libaray for NPU or GPU, e.g., NCCL and HCCL.
// When only CPU backend is used, device communication library should not be initialized.
void *device_comm_lib_;
CollectiveCommunicationLib *device_comm_lib_instance_;

// The global rank id of this process. Normally this range is 0 to `total process number - 1`.
@@ -107,6 +110,10 @@ class CollectiveManager {

// Global group ranks.
std::vector<uint32_t> global_group_ranks_;

// The global group name on the host side. This is used for Creating global group on host side for AllGather operation
// of host name while assigning local rank.
std::string host_global_group_name_;
};
} // namespace collective
} // namespace distributed


+ 1
- 0
mindspore/ccsrc/distributed/constants.h View File

@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_

#include <set>
#include <map>
#include <string>

namespace mindspore {


+ 18
- 7
mindspore/ccsrc/distributed/init.cc View File

@@ -20,16 +20,29 @@

namespace mindspore {
namespace distributed {
bool Initialize(const std::string &backend, const std::string &global_group_name) {
bool Initialize() {
if (!InitializeCluster()) {
MS_LOG(ERROR) << "Failed to initialize cluster.";
return false;
}

if (!InitializeCollective(backend, global_group_name)) {
MS_LOG(ERROR) << "Failed to initialize collective communication.";
return false;
#if ((defined ENABLE_CPU) && (!defined _WIN32))
// Server and Scheduler don't use collective communication library.
auto node = cluster::ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
if (node->role() != ps::core::NodeRole::SERVER && node->role() != ps::core::NodeRole::SCHEDULER) {
// Global rank id and size should be manually set if cluster is initialized by MindSpore communication framework.
auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster::ClusterContext::instance()->node());
MS_EXCEPTION_IF_NULL(abstract_node);
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());

if (!InitializeCollective()) {
MS_LOG(ERROR) << "Failed to initialize collective communication.";
return false;
}
}
#endif
return true;
}

@@ -51,9 +64,7 @@ bool InitializeCluster() { return cluster::ClusterContext::instance()->Initializ

bool FinalizeCluster() { return cluster::ClusterContext::instance()->Finalize(); }

bool InitializeCollective(const std::string &backend, const std::string &global_group_name) {
return collective::CollectiveManager::instance()->Initialize(backend, global_group_name);
}
bool InitializeCollective() { return collective::CollectiveManager::instance()->Initialize(); }

bool FinalizeCollective() { return collective::CollectiveManager::instance()->Finalize(); }
} // namespace distributed


+ 2
- 2
mindspore/ccsrc/distributed/init.h View File

@@ -31,7 +31,7 @@ namespace distributed {
// The static methods of MindSpore distributed execution. They can be exported by Pybind.

// Initialize and finalize distributed execution.
bool Initialize(const std::string &backend, const std::string &global_group_name);
bool Initialize();
bool Finalize();

// Initialize and finalize the cluster based on MindSpore communication framework.
@@ -39,7 +39,7 @@ bool InitializeCluster();
bool FinalizeCluster();

// Initialize and finalize collective communication for distributed execution.
bool InitializeCollective(const std::string &backend, const std::string &global_group_name);
bool InitializeCollective();
bool FinalizeCollective();
} // namespace distributed
} // namespace mindspore


+ 2
- 1
mindspore/ccsrc/ps/ps_context.cc View File

@@ -410,7 +410,8 @@ void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }

core::ClusterConfig &PSContext::cluster_config() {
if (cluster_config_ == nullptr) {
MS_LOG(EXCEPTION) << "The cluster config is empty.";
cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
MS_EXCEPTION_IF_NULL(cluster_config_);
}
return *cluster_config_;
}


+ 90
- 16
mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc View File

@@ -16,6 +16,8 @@

#include "runtime/device/gpu/distribution/collective_init.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "distributed/init.h"

namespace mindspore {
namespace device {
@@ -30,23 +32,32 @@ bool CollectiveInitializer::collective_inited() const { return collective_inited
const void *CollectiveInitializer::collective_handle() { return collective_handle_; }

void CollectiveInitializer::InitCollective() {
void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
if (handle == nullptr) {
MS_LOG(EXCEPTION)
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n"
"1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built "
"with distributed feature.\n"
"2.NCCL is not found or the user-installed NCCL version installed is incompatible: MindSpore "
"requires NCCL-2.7.6.\n"
"3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore "
"requires OpenMPI-4.0.3.\n";
}
auto mpi_init_funcptr = reinterpret_cast<InitMPI>(dlsym(handle, "InitMPI"));
MS_EXCEPTION_IF_NULL(mpi_init_funcptr);
(*mpi_init_funcptr)();
if (common::CheckUseMPI()) {
void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
if (handle == nullptr) {
MS_LOG(EXCEPTION)
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n"
"1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built "
"with distributed feature.\n"
"2.NCCL is not found or the user-installed NCCL version installed is incompatible: MindSpore "
"requires NCCL-2.7.6.\n"
"3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore "
"requires OpenMPI-4.0.3.\n";
}
auto mpi_init_funcptr = reinterpret_cast<InitMPI>(dlsym(handle, "InitMPI"));
MS_EXCEPTION_IF_NULL(mpi_init_funcptr);
(*mpi_init_funcptr)();

CollectiveInitializer::instance().collective_inited_ = true;
CollectiveInitializer::instance().collective_handle_ = handle;
// Because this method InitCollective is static, the non-static member variables should be accessed by
// CollectiveInitializer::instance().
CollectiveInitializer::instance().use_mpi_ = true;
CollectiveInitializer::instance().collective_inited_ = true;
CollectiveInitializer::instance().collective_handle_ = handle;
} else {
if (!distributed::Initialize()) {
MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL.";
}
}
}

void CollectiveInitializer::FinalizeCollective() {
@@ -56,6 +67,69 @@ void CollectiveInitializer::FinalizeCollective() {
}
}
}

uint32_t CollectiveInitializer::local_rank_id() {
uint32_t local_rank_id;
if (common::CheckUseMPI()) {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto get_local_rank_funcptr =
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
local_rank_id = IntToUint((*get_local_rank_funcptr)());
} else {
local_rank_id = distributed::collective::CollectiveManager::instance()->local_rank_id();
}
return local_rank_id;
}

bool CollectiveInitializer::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto create_comm_group_funcptr =
reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
return (*create_comm_group_funcptr)(group_name, group_ranks);
}
}

bool CollectiveInitializer::DestroyCommunicationGroup(const std::string &group_name) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto destroy_group_funcptr =
reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
return (*destroy_group_funcptr)(group_name);
}
}

uint32_t CollectiveInitializer::GetRankIDByGroup(const std::string &group_name) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->GetRankId(group_name);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto get_rank_id_funcptr =
reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
return IntToUint((*get_rank_id_funcptr)(group_name));
}
}

uint32_t CollectiveInitializer::GetGroupSize(const std::string &group_name) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto get_group_size_funcptr =
reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
return IntToUint((*get_group_size_funcptr)(group_name));
}
}
} // namespace gpu
} // namespace device
} // namespace mindspore

+ 10
- 2
mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h View File

@@ -42,12 +42,20 @@ class CollectiveInitializer {
static void InitCollective();
static void FinalizeCollective();

// The capsulation of the collective communication APIs for compatibility.
uint32_t local_rank_id();
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);
bool DestroyCommunicationGroup(const std::string &group_name);
uint32_t GetRankIDByGroup(const std::string &group_name);
uint32_t GetGroupSize(const std::string &group_name);

private:
CollectiveInitializer() : collective_inited_(false) {}
CollectiveInitializer() : use_mpi_(false), collective_inited_(false), collective_handle_(nullptr) {}
~CollectiveInitializer() = default;

bool use_mpi_;
bool collective_inited_;
void *collective_handle_{nullptr};
void *collective_handle_;
};
} // namespace gpu
} // namespace device


+ 3
- 0
mindspore/ccsrc/runtime/hardware/CMakeLists.txt View File

@@ -23,6 +23,9 @@ endif()
if(ENABLE_CPU)
file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc")
if(WIN32)
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/ms_collective_comm_lib.cc")
endif()
if(ENABLE_MPI)
set(MPI_COLLECTIVE_SRCS "cpu/mpi_collective_comm_lib.cc"
"cpu/mpi_communication_group.cc"


+ 2
- 0
mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc View File

@@ -60,6 +60,8 @@ CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &gr
return groups_[group_name];
}

const std::string &CollectiveCommunicationLib::global_group_name() const { return global_group_name_; }

uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }

uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }


+ 7
- 0
mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h View File

@@ -77,6 +77,10 @@ class CollectiveCommunicationLib {
return true;
}

// Returns the global group name of this collective communication library. For NCCL, it's 'nccl_world_group'. For
// HCCL, it's 'hccl_world_group'.
const std::string &global_group_name() const;

// Returns global rank id of this process.
uint32_t global_rank_id() const;

@@ -90,6 +94,9 @@ class CollectiveCommunicationLib {
// Whether this collective communication library is initialized.
bool initialized_;

// The global group name.
std::string global_group_name_;

// The global rank id of this process. Normally this range is 0 to `total process number - 1`.
uint32_t global_rank_id_;



+ 6
- 0
mindspore/ccsrc/runtime/hardware/collective/communication_group.cc View File

@@ -47,5 +47,11 @@ uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
}

uint32_t CommunicationGroup::group_size() const { return size_; }

const std::vector<uint32_t> &CommunicationGroup::group_ranks() const { return group_ranks_; }

const std::map<uint32_t, uint32_t> &CommunicationGroup::global_to_group_ranks() const { return global_to_group_ranks_; }

const std::map<uint32_t, uint32_t> &CommunicationGroup::group_to_global_ranks() const { return group_to_global_ranks_; }
} // namespace device
} // namespace mindspore

+ 5
- 0
mindspore/ccsrc/runtime/hardware/collective/communication_group.h View File

@@ -55,6 +55,11 @@ class CommunicationGroup {
// Return the size of this communication group.
uint32_t group_size() const;

// Return group ranks info.
const std::vector<uint32_t> &group_ranks() const;
const std::map<uint32_t, uint32_t> &global_to_group_ranks() const;
const std::map<uint32_t, uint32_t> &group_to_global_ranks() const;

protected:
// Whether this communication group is initialized.
bool initialized_;


+ 29
- 0
mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc View File

@@ -34,6 +34,9 @@
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "profiler/device/cpu/cpu_profiling.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "runtime/hardware/cpu/ms_collective_comm_lib.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
@@ -296,6 +299,32 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
return DoLaunchKernel(kernel_mod, inputs, workspace, outputs);
}

bool CPUDeviceContext::LoadCollectiveCommLib() {
bool using_mpi = common::CheckUseMPI();
if (using_mpi) {
std::string mpi_comm_lib_name = "libmpi_collective.so";
auto loader = std::make_shared<CollectiveCommLibLoader>(mpi_comm_lib_name);
MS_EXCEPTION_IF_NULL(loader);
if (!loader->Initialize()) {
MS_LOG(EXCEPTION) << "Failed to load mpi collective library.";
return false;
}

void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);

auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
collective_comm_lib_ = instance_func();
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
} else {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
collective_comm_lib_ = &MsCollectiveCommLib::GetInstance();
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
#endif
}
return true;
}

bool CPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {


+ 2
- 0
mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h View File

@@ -57,6 +57,8 @@ class CPUDeviceContext : public DeviceContext {
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
bool is_dynamic_shape = false) const override;

bool LoadCollectiveCommLib() override;

private:
DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);



+ 2
- 42
mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc View File

@@ -19,6 +19,8 @@
namespace mindspore {
namespace device {
namespace cpu {
MPICollectiveCommLib::MPICollectiveCommLib() { global_group_name_ = kMPIGlobalGroupName; }

bool MPICollectiveCommLib::Initialize(uint32_t, uint32_t) {
if (initialized_) {
return false;
@@ -56,49 +58,7 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam
}
} // namespace cpu

// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;

CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); }

bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
return MPICollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}

bool DestroyCommunicationGroup(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}

uint32_t GetRankId(const std::string &group_name) { return MPICollectiveCommLib::GetInstance().GetRankId(group_name); }

uint32_t GetCommunicationGroupSize(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().GetGroupSize(group_name);
}

bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream);
}
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}

uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); }

uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }

uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); }
} // namespace device
} // namespace mindspore

+ 7
- 23
mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h View File

@@ -23,10 +23,15 @@
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/cpu/mpi_communication_group.h"

#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif

namespace mindspore {
namespace device {
namespace cpu {
class MPICollectiveCommLib : public CollectiveCommunicationLib {
constexpr char kMPIGlobalGroupName[] = "mpi_world_group";
class EXPORT_MPI_WRAPPER MPICollectiveCommLib : public CollectiveCommunicationLib {
public:
static MPICollectiveCommLib &GetInstance() {
static MPICollectiveCommLib instance;
@@ -49,35 +54,14 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
}

private:
MPICollectiveCommLib() = default;
MPICollectiveCommLib();
~MPICollectiveCommLib() override = default;

MPI_Group world_group_;
};
} // namespace cpu

#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_MPI_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size();
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_

+ 2
- 2
mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.cc View File

@@ -19,6 +19,8 @@
namespace mindspore {
namespace device {
namespace cpu {
MsCollectiveCommLib::MsCollectiveCommLib() { global_group_name_ = kMSGlobalGroupName; }

bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
if (initialized_) {
return false;
@@ -30,8 +32,6 @@ bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_
return true;
}

bool MsCollectiveCommLib::Finalize() { return true; }

bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
CHECK_RET((groups_.count(group_name) == 0), true, "The group " + group_name + " has already existed.");


+ 8
- 2
mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.h View File

@@ -22,10 +22,17 @@
#include <string>
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/cpu/ms_communication_group.h"
#include "distributed/cluster/cluster_context.h"
#include "fl/server/collective_ops_impl.h"

namespace mindspore {
namespace device {
namespace cpu {
constexpr char kMSGlobalGroupName[] = "ms_world_group";
using ClusterContext = mindspore::distributed::cluster::ClusterContext;
using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl;
using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;

// The collective communication library for MindSpore self developed communication framework.
class MsCollectiveCommLib : public CollectiveCommunicationLib {
public:
@@ -35,12 +42,11 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {
}

bool Initialize(uint32_t global_rank = UINT32_MAX, uint32_t global_rank_size = UINT32_MAX) override;
bool Finalize() override;

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

private:
MsCollectiveCommLib() {}
MsCollectiveCommLib();
~MsCollectiveCommLib() override = default;
};
} // namespace cpu


+ 4
- 4
mindspore/ccsrc/runtime/hardware/device_context.h View File

@@ -50,7 +50,7 @@ struct DeviceContextKey {
class DeviceContext {
public:
explicit DeviceContext(const DeviceContextKey &device_context_key)
: device_context_key_(device_context_key), collective_comm_lib_ptr_(nullptr) {}
: device_context_key_(device_context_key), collective_comm_lib_(nullptr) {}
virtual ~DeviceContext() = default;

// Initialize the device context.
@@ -150,7 +150,7 @@ class DeviceContext {
virtual bool LoadCollectiveCommLib() { return true; }

// Return collective communication object for caller to access
void *collective_comm_lib() const { return collective_comm_lib_ptr_; }
CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }

// TODO(jiaorui): will be delete
// Dump all graphs.
@@ -159,8 +159,8 @@ class DeviceContext {
protected:
DeviceContextKey device_context_key_;

// The dynamically loaded handle for collective communication library by 'dlopen'.
void *collective_comm_lib_ptr_;
// The collective communication library.
CollectiveCommunicationLib *collective_comm_lib_;
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;
} // namespace device


+ 6
- 2
mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc View File

@@ -534,8 +534,12 @@ bool GPUDeviceContext::LoadCollectiveCommLib() {
MS_LOG(EXCEPTION) << "Loading NCCL collective library failed.";
return false;
}
collective_comm_lib_ptr_ = loader->collective_comm_lib_ptr();
MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);

auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
collective_comm_lib_ = instance_func();
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
return true;
#else
return false;


+ 2
- 43
mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc View File

@@ -19,6 +19,8 @@
namespace mindspore {
namespace device {
namespace gpu {
NvidiaCollectiveCommLib::NvidiaCollectiveCommLib() { global_group_name_ = kNCCLGlobalGroupName; }

bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
if (initialized_) {
return false;
@@ -42,50 +44,7 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_
}
} // namespace gpu

// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;
CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}

bool FinalizeCollectiveLib() { return NvidiaCollectiveCommLib::GetInstance().Finalize(); }

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
return NvidiaCollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}

bool DestroyCommunicationGroup(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}

uint32_t GetRankId(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetRankId(group_name);
}

uint32_t GetCommunicationGroupSize(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetGroupSize(group_name);
}

bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name,
stream);
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}

uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }
} // namespace device
} // namespace mindspore

+ 7
- 22
mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h View File

@@ -24,11 +24,15 @@
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/gpu/nvidia_communication_group.h"

#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif

namespace mindspore {
namespace device {
namespace gpu {
constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
constexpr char kNCCLGlobalGroupName[] = "nccl_world_group";
class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
public:
static NvidiaCollectiveCommLib &GetInstance() {
static NvidiaCollectiveCommLib instance;
@@ -50,31 +54,12 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
}

private:
NvidiaCollectiveCommLib() = default;
NvidiaCollectiveCommLib();
~NvidiaCollectiveCommLib() override = default;
};
} // namespace gpu

#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_NCCL_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_

+ 10
- 0
mindspore/core/utils/ms_utils.h View File

@@ -90,6 +90,16 @@ static inline bool IsLittleByteOrder() {
return false;
}

static inline bool CheckUseMPI() {
// If these OpenMPI environment variables are set, we consider this process is launched by OpenMPI.
std::string ompi_command_env = GetEnv("OMPI_COMMAND");
std::string pmix_rank_env = GetEnv("PMIX_RANK");
if (!ompi_command_env.empty() && !pmix_rank_env.empty()) {
return true;
}
return false;
}

template <typename T>
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
if (a == b) {


Loading…
Cancel
Save