Browse Source

Add implement of collective manager API

tags/v1.6.0
ZPaC 4 years ago
parent
commit
055a493903
13 changed files with 292 additions and 57 deletions
  1. +160
    -20
      mindspore/ccsrc/distributed/collective/collective_manager.cc
  2. +20
    -9
      mindspore/ccsrc/distributed/collective/collective_manager.h
  3. +1
    -0
      mindspore/ccsrc/distributed/constants.h
  4. +2
    -12
      mindspore/ccsrc/runtime/hardware/collective/collective_comm_lib_loader.h
  5. +5
    -0
      mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc
  6. +16
    -0
      mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h
  7. +3
    -4
      mindspore/ccsrc/runtime/hardware/collective/communication_group.h
  8. +19
    -2
      mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc
  9. +19
    -2
      mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h
  10. +20
    -2
      mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc
  11. +19
    -2
      mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h
  12. +7
    -3
      mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.cc
  13. +1
    -1
      mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.h

+ 160
- 20
mindspore/ccsrc/distributed/collective/collective_manager.cc View File

@@ -28,7 +28,9 @@ CollectiveManager::CollectiveManager()
host_ctx_(nullptr),
device_ctx_(nullptr),
host_comm_lib_(nullptr),
host_comm_lib_instance_(nullptr),
device_comm_lib_(nullptr),
device_comm_lib_instance_(nullptr),
global_rank_id_(0),
local_rank_id_(0),
global_rank_size_(0),
@@ -61,13 +63,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
return false;
}

MS_EXCEPTION_IF_NULL(host_comm_lib_);
// Step 2: Create global communication group on host side.
if (!CreateHostGlobalCommGroup(global_group_name)) {
// Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will not
// be necessary.
// Step 2: Assign local rank id(device id) for this process.
if (!AssignLocalRank(global_group_name)) {
MS_LOG(ERROR) << "Failed to assign local rank id.";
return false;
}

// Step 3: Initialize device side collective communication.
if (!InitDeviceCommLib(backend)) {
MS_LOG(ERROR) << "Failed to initialize device communication library.";
return false;
}

// Step 4: Create global communication group.
if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) {
MS_LOG(ERROR) << "Failed to initialize host communication library.";
return false;
}
// Step 3: Assign local rank id(device id) for this process.

MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << ".";
return true;
@@ -75,25 +89,81 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string

// Creates the communication group 'group_name' containing 'group_ranks' on both the host side and the device side.
// The device-side group (e.g., an NCCL communicator) cannot bootstrap itself: its root information must be produced
// by the group's root rank and distributed to the other ranks through the already-initialized host-side group.
bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
                                                 const std::vector<uint32_t> &group_ranks) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_);
  MS_EXCEPTION_IF_NULL(device_comm_lib_);
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);

  // Step 1: Create communication group on host side.
  if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
    MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side.";
    return false;
  }

  // Step 2: Create communication group on device side.
  if (!device_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
    MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on device side.";
    return false;
  }

  // Step 3: Generate device information of the root node.
  // For NCCL this is a pointer to 'ncclUniqueId'; only the root rank fills in real data.
  CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name);
  MS_EXCEPTION_IF_NULL(group);
  size_t root_info_size = 0;
  void *root_info = group->GenerateRootInfo(&root_info_size);
  MS_EXCEPTION_IF_NULL(root_info);

  // Step 4: Broadcast the device root information to all nodes on host side.
  // 'root_info' is used as both send and receive buffer: the root rank sends its info, the others
  // receive it into the same buffer in place.
  if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt, 0, group_name,
                                          nullptr)) {
    MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
    return false;
  }

  // Step 5: Initialize communication group on the device side with the broadcast root info.
  if (!group->Initialize(root_info)) {
    MS_LOG(ERROR) << "Initialize group on the device side failed.";
    return false;
  }
  return true;
}

bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) { return true; }
// Destroys the communication group 'group_name': first the host-side group, then the device-side one.
// Returns false as soon as either destruction fails.
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  const bool host_destroyed = host_comm_lib_instance_->DestroyCommunicationGroup(group_name);
  if (!host_destroyed) {
    MS_LOG(ERROR) << "Failed to destroy communication group of " << group_name << " on the host side.";
    return false;
  }

  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
  const bool device_destroyed = device_comm_lib_instance_->DestroyCommunicationGroup(group_name);
  if (!device_destroyed) {
    MS_LOG(ERROR) << "Failed to destroy communication group of " << group_name << " on the device side.";
    return false;
  }
  return true;
}

uint32_t CollectiveManager::GetRankId(const std::string &group_name) { return 0; }
// Returns this process's rank id in group 'group_name', as tracked by the host-side communication library.
uint32_t CollectiveManager::GetRankId(const std::string &group_name) {
  auto *host_lib = host_comm_lib_instance_;
  MS_EXCEPTION_IF_NULL(host_lib);
  return host_lib->GetRankId(group_name);
}

uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) { return 0; }
// Returns the number of ranks in group 'group_name', as tracked by the host-side communication library.
uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) {
  auto *host_lib = host_comm_lib_instance_;
  MS_EXCEPTION_IF_NULL(host_lib);
  return host_lib->GetGroupSize(group_name);
}

// Finalizes both the host-side and the device-side communication libraries.
// Failures are only warned about (not fatal) so teardown always completes; repeated calls are no-ops.
bool CollectiveManager::Finalize() {
  if (finalized_) {
    return true;
  }

  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  if (!host_comm_lib_instance_->Finalize()) {
    MS_LOG(WARNING) << "Failed to finalize host communication library.";
  }

  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
  if (!device_comm_lib_instance_->Finalize()) {
    MS_LOG(WARNING) << "Failed to finalize device communication library.";
  }

  // Bug fix: mark the manager as finalized so the guard above actually prevents double finalization.
  finalized_ = true;
  return true;
}

@@ -105,19 +175,34 @@ bool CollectiveManager::InitHostCommlib() {
MS_LOG(ERROR) << "Failed to load communication library on the host side.";
return false;
}
return true;
}

bool CollectiveManager::CreateHostGlobalCommGroup(const std::string &global_group_name) {
host_comm_lib_ = host_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_);
if (global_group_ranks_.empty()) {
MS_LOG(ERROR) << "The global group rank list is empty.";
auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_);
host_comm_lib_instance_ = instance_func();
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);

// For some communication libraries, 'global_rank_id_' and 'global_rank_size_' should be set by the caller, e.g., when
// using MindSpore communication. For other communication libraries, the global rank id and size are generated by the
// library itself, e.g., OpenMPI, and the parameters 'global_rank_id_' and 'global_rank_size_' will not be used.
MS_LOG(INFO) << "Start initializing communication library on host side...";
if (!host_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_)) {
MS_LOG(ERROR) << "Failed to initialize communication library on host side.";
return false;
}

// Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
global_rank_id_ = host_comm_lib_instance_->global_rank_id();
global_rank_size_ = host_comm_lib_instance_->global_rank_size();
for (uint32_t i = 0; i < global_rank_size_; i++) {
global_group_ranks_.push_back(i);
}

MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
<< ", global rank size: " << global_rank_size_;
return true;
}

bool CollectiveManager::InitDeviceCommLib(const std::string &backend, uint32_t device_id) {
bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
std::string device_name;
if (backend == "nccl") {
device_name = "GPU";
@@ -128,13 +213,68 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend, uint32_t d
return false;
}

device::DeviceContextKey device_key = {device_name, device_id};
device::DeviceContextKey device_key = {device_name, local_rank_id_};
device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
MS_EXCEPTION_IF_NULL(device_ctx_);
if (!device_ctx_->LoadCollectiveCommLib()) {
MS_LOG(ERROR) << "Failed to load communication library on the device side.";
return false;
}
device_comm_lib_ = device_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(device_comm_lib_);
auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_);
device_comm_lib_instance_ = instance_func();
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);

MS_LOG(INFO) << "Start initializing communication library on device side...";
if (!device_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_)) {
MS_LOG(ERROR) << "Failed to initialize communication library on device side.";
return false;
}
MS_LOG(INFO) << "Communication library on device side is successfully initialized.";
return true;
}

bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
char host_name[MAX_HOSTNAME_LEN] = {0};
#ifndef _WIN32
if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
MS_LOG(ERROR) << "Failed to get host name.";
return false;
}
#endif
MS_LOG(INFO) << "Host name for rank " << global_rank_id_ << " is " << host_name;

// Generate host name hash for every process. The host names of different physical machine should not be the same so
// that local rank id won't repeat.
size_t host_hash = std::hash<std::string>()(host_name);
const uint32_t kGlobalRankSize = global_rank_size_;
size_t all_host_hashs[kGlobalRankSize];
if (global_rank_id_ >= global_rank_size_) {
MS_LOG(ERROR) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
<< global_rank_size_;
return false;
}
all_host_hashs[global_rank_id_] = host_hash;

MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
// AllGather host names across the global communication group.
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, sizeof(size_t), TypeId::kNumberTypeInt,
global_group_name, nullptr)) {
MS_LOG(ERROR) << "AllGather for host names failed.";
return false;
}

// Accumulate rank id.
for (uint32_t rank = 0; rank < global_rank_size_; rank++) {
if (rank == global_rank_id_) {
break;
}
if (all_host_hashs[rank] == all_host_hashs[global_rank_id_]) {
local_rank_id_++;
}
}
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_;
return true;
}
} // namespace collective


+ 20
- 9
mindspore/ccsrc/distributed/collective/collective_manager.h View File

@@ -22,6 +22,7 @@
#include <vector>
#include <atomic>
#include "utils/ms_utils.h"
#include "distributed/constants.h"
#include "runtime/hardware/device_context_manager.h"

namespace mindspore {
@@ -30,6 +31,8 @@ namespace collective {
using DeviceContext = device::DeviceContext;
using DeviceContextKey = device::DeviceContextKey;
using DeviceContextManager = device::DeviceContextManager;
using CollectiveCommunicationLib = device::CollectiveCommunicationLib;
using CommunicationGroupPtr = device::CommunicationGroupPtr;

// The collective communication API.
// MindSpore uses OpenMPI on CPU, NCCL on GPU, HCCL on Ascend, to achieve distributed training.
@@ -43,6 +46,9 @@ class CollectiveManager {
// Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL.
bool Initialize(const std::string &backend, const std::string &global_group_name);

// Finalize the collective communication.
bool Finalize();

// Create communication group.
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);

@@ -55,8 +61,10 @@ class CollectiveManager {
// Get the size of the specified group.
uint32_t GetGroupSize(const std::string &group_name);

// Finalize the collective communication.
bool Finalize();
// In some cases global rank id and rank size should be set by caller, e.g., when using MindSpore communication
// framework, they're generated by cluster::ClusterContext.
// NOTE(review): these are named like setters but take no parameters and return uint32_t — they look like they
// should be 'void set_global_rank_id(uint32_t global_rank_id)' and 'void set_global_rank_size(uint32_t
// global_rank_size)'. Confirm against their definitions before relying on them.
uint32_t set_global_rank_id();
uint32_t set_global_rank_size();

private:
CollectiveManager();
@@ -64,14 +72,11 @@ class CollectiveManager {
// Initialize communication library on host side.
bool InitHostCommlib();

// Create world communication group on the host side.
bool CreateHostGlobalCommGroup(const std::string &global_group_name);

// Initialize communication library on device side.
bool InitDeviceCommLib(const std::string &backend, uint32_t device_id);
bool InitDeviceCommLib(const std::string &backend);

// Create world communication group on the device side.
bool CreateDeviceGlobalCommGroup(const std::string &global_group_name);
// Assign the local rank id for this process.
bool AssignLocalRank(const std::string &global_group_name);

std::atomic_bool inited_;
std::atomic_bool finalized_;
@@ -81,9 +86,15 @@ class CollectiveManager {
DeviceContext *host_ctx_;
DeviceContext *device_ctx_;

// The dynamically loaded handle for collective communication library by 'dlopen'.
// Host communication library refers to the communication library for CPU, e.g., OpenMPI and the MindSpore
// communication framework.
void *host_comm_lib_;
CollectiveCommunicationLib *host_comm_lib_instance_;

// Device communication library refers to the communication library for NPU or GPU, e.g., NCCL and HCCL.
// When only CPU backend is used, device communication library should not be initialized.
void *device_comm_lib_;
CollectiveCommunicationLib *device_comm_lib_instance_;

// The global rank id of this process. Normally this range is 0 to `total process number - 1`.
uint32_t global_rank_id_;


+ 1
- 0
mindspore/ccsrc/distributed/constants.h View File

@@ -34,6 +34,7 @@ constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
const std::set<std::string> kValidRoleName = {kEnvRoleOfServer, kEnvRoleOfWorker, kEnvRoleOfScheduler};

constexpr char kLocalHost[] = "127.0.0.1";
constexpr int MAX_HOSTNAME_LEN = 1024;
const uint16_t kDefaultSchedPort = 6667;
const uint16_t kMaxPort = 65535;
} // namespace distributed


+ 2
- 12
mindspore/ccsrc/runtime/hardware/collective/collective_comm_lib_loader.h View File

@@ -21,6 +21,7 @@
#include <memory>
#include <vector>
#include "utils/dlopen_macro.h"
#include "runtime/hardware/collective/collective_communication_lib.h"

namespace mindspore {
namespace device {
@@ -51,17 +52,6 @@ using CollectiveCommLibLoaderPtr = std::shared_ptr<CollectiveCommLibLoader>;
} // namespace device
} // namespace mindspore

#ifndef _WIN32
// The exported symbols of collective communication shared library is registered here.
ORIGIN_METHOD(InitializeCollectiveLib, bool, uint32_t, uint32_t)
ORIGIN_METHOD(FinalizeCollectiveLib, bool)
ORIGIN_METHOD(CreateCommunicationGroup, bool, const std::string &, const std::vector<uint32_t> &)
ORIGIN_METHOD(DestroyCommunicationGroup, bool, const std::string &)
ORIGIN_METHOD(GetRankId, uint32_t, const std::string &)
ORIGIN_METHOD(GetCommunicationGroupSize, uint32_t, const std::string &)
ORIGIN_METHOD(AssignLocalRank, bool)
ORIGIN_METHOD(global_rank_id, uint32_t)
ORIGIN_METHOD(local_rank_id, uint32_t)
ORIGIN_METHOD(global_rank_size, uint32_t)
#endif
ORIGIN_METHOD(communication_lib_instance, mindspore::device::CollectiveCommunicationLib *)
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_

+ 5
- 0
mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc View File

@@ -55,6 +55,11 @@ uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name)
return group->group_size();
}

// Returns the communication group registered under 'group_name'.
// Fails via CHECK_RET if no such group has been created.
CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &group_name) {
  // Single map lookup instead of count() followed by operator[].
  auto iter = groups_.find(group_name);
  CHECK_RET(iter != groups_.end(), true, "The group " + group_name + " does not exist.");
  return iter->second;
}

uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }

uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }


+ 16
- 0
mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h View File

@@ -21,6 +21,7 @@
#include <memory>
#include <vector>
#include <string>
#include "ir/dtype/type_id.h"
#include "runtime/hardware/collective/communication_group.h"

namespace mindspore {
@@ -61,6 +62,21 @@ class CollectiveCommunicationLib {
// Assign the local rank id for this process. Normally used by collective communication library on the host side.
virtual bool AssignLocalRank() { return true; }

// Return communication group pointer.
virtual CommunicationGroupPtr GetGroup(const std::string &group_name);

// Primitive of AllGather operation.
virtual bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) {
return true;
}

// Primitive of Broadcast operation.
virtual bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return true;
}

// Returns global rank id of this process.
uint32_t global_rank_id() const;



+ 3
- 4
mindspore/ccsrc/runtime/hardware/collective/communication_group.h View File

@@ -44,10 +44,9 @@ class CommunicationGroup {
// Finalize the communication group. For example, destroy the group, etc.
virtual bool Finalize() = 0;

// Return the root rank's information. Only root rank of one group could call this method.Normally this is used for
// collective libraries on the device side. For NCCL group, it returns 'ncclUniqueId'. For HCCL group, it returns
// 'HcclRootInfo'.
virtual void *GenerateRootInfo() { return nullptr; }
// Return the root rank's information and its size. Normally this is used for collective libraries on the device side.
// For NCCL group, it returns a pointer to 'ncclUniqueId'. For HCCL group, it returns a pointer to 'HcclRootInfo'.
virtual void *GenerateRootInfo(size_t *root_info_size) { return nullptr; }

// Get group or global rank for the given rank.
uint32_t GetGroupRank(uint32_t global_rank);


+ 19
- 2
mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc View File

@@ -55,11 +55,12 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam
return true;
}
} // namespace cpu
} // namespace device
} // namespace mindspore

// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;

CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); }

bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }
@@ -80,8 +81,24 @@ uint32_t GetCommunicationGroupSize(const std::string &group_name) {

bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream);
}
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}

uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); }

uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }

uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); }
} // namespace device
} // namespace mindspore

+ 19
- 2
mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h View File

@@ -38,6 +38,16 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
// Override creating method. Reuse destroying method in base class CollectiveCommunicationLib.
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) override {
return true;
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream) override {
return true;
}

private:
MPICollectiveCommLib() = default;
~MPICollectiveCommLib() override = default;
@@ -45,12 +55,11 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
MPI_Group world_group_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore

#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();
@@ -60,7 +69,15 @@ extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size();
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_

+ 20
- 2
mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc View File

@@ -41,11 +41,11 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_
return true;
}
} // namespace gpu
} // namespace device
} // namespace mindspore

// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;
CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}
@@ -70,4 +70,22 @@ uint32_t GetCommunicationGroupSize(const std::string &group_name) {

bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name,
stream);
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}

uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }
} // namespace device
} // namespace mindspore

+ 19
- 2
mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h View File

@@ -39,17 +39,26 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) override {
return true;
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream) override {
return true;
}

private:
NvidiaCollectiveCommLib() = default;
~NvidiaCollectiveCommLib() override = default;
};
} // namespace gpu
} // namespace device
} // namespace mindspore

#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();
@@ -59,5 +68,13 @@ extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_

+ 7
- 3
mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.cc View File

@@ -21,7 +21,7 @@ namespace device {
namespace gpu {
NvidiaCommunicationGroup::NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
uint32_t global_rank)
: CommunicationGroup(name, group_ranks, global_rank) {}
: CommunicationGroup(name, group_ranks, global_rank), unique_id_({}), comm_(nullptr) {}

bool NvidiaCommunicationGroup::Initialize(void *root_info) {
if (initialized_) {
@@ -50,8 +50,12 @@ bool NvidiaCommunicationGroup::Finalize() {
return true;
}

void *NvidiaCommunicationGroup::GenerateRootInfo() {
CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
// Returns a pointer to this group's 'ncclUniqueId' and outputs its size through 'root_info_size'.
// Only the root of the group (group rank 0) asks NCCL for a fresh unique id; the other ranks return their
// still-uninitialized 'unique_id_' buffer, which the caller is expected to overwrite via a host-side broadcast
// of the root's id before Initialize() consumes it.
void *NvidiaCommunicationGroup::GenerateRootInfo(size_t *root_info_size) {
  *root_info_size = sizeof(unique_id_);
  uint32_t group_rank = GetGroupRank(global_rank_);
  if (group_rank == 0) {
    CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
  }
  return &unique_id_;
}
} // namespace gpu


+ 1
- 1
mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.h View File

@@ -37,7 +37,7 @@ class NvidiaCommunicationGroup : public CommunicationGroup {
bool Initialize(void *root_info) override;
bool Finalize() override;

void *GenerateRootInfo() override;
void *GenerateRootInfo(size_t *root_info_size) override;

private:
// The NCCL unique id for this group. Used to initialize this group's communicator.


Loading…
Cancel
Save