Browse Source

!31811 Optimize the protocol of initialization collective communication

Merge pull request !31811 from zyli2020/master
r1.7
i-robot Gitee 4 years ago
parent
commit
8fb35ecb95
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 463 additions and 23 deletions
  1. +4
    -14
      mindspore/ccsrc/distributed/collective/collective_manager.cc
  2. +162
    -0
      mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_comm_lib.cc
  3. +21
    -0
      mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_comm_lib.h
  4. +6
    -5
      mindspore/ccsrc/ps/core/abstract_node.cc
  5. +4
    -1
      mindspore/ccsrc/ps/core/abstract_node.h
  6. +7
    -0
      mindspore/ccsrc/ps/core/abstract_ps_node.cc
  7. +3
    -0
      mindspore/ccsrc/ps/core/abstract_ps_node.h
  8. +57
    -0
      mindspore/ccsrc/ps/core/protos/comm.proto
  9. +140
    -0
      mindspore/ccsrc/ps/core/ps_scheduler_node.cc
  10. +38
    -1
      mindspore/ccsrc/ps/core/ps_scheduler_node.h
  11. +1
    -0
      mindspore/ccsrc/ps/core/scheduler_node.cc
  12. +3
    -0
      mindspore/ccsrc/ps/core/scheduler_node.h
  13. +2
    -1
      mindspore/ccsrc/runtime/collective/collective_communication_lib.cc
  14. +15
    -1
      mindspore/ccsrc/runtime/collective/collective_communication_lib.h

+ 4
- 14
mindspore/ccsrc/distributed/collective/collective_manager.cc View File

@@ -120,18 +120,14 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
// Step 3: Generate device information of the root node.
CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name);
MS_EXCEPTION_IF_NULL(group);
bool is_root_node = (group->GetGroupRank(global_rank_id_) == 0);
size_t root_info_size = 0;
void *root_info = group->GenerateRootInfo(&root_info_size);
MS_EXCEPTION_IF_NULL(root_info);

// Step 4: Broadcast the device root information to all nodes on host side.
if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt8, 0,
group_name)) {
if (!host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info)) {
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
runtime::recovery::RecoveryErrCode::kBroadcastUniqueIDFailed);
}
return false;
}

@@ -305,7 +301,7 @@ bool CollectiveManager::AssignLocalRank() {
// that local rank id won't repeat.
size_t host_hash = std::hash<std::string>()(host_name);
const uint32_t kGlobalRankSize = global_rank_size_;
size_t all_host_hashs[kGlobalRankSize];
std::vector<size_t> all_host_hashs(kGlobalRankSize);
if (global_rank_id_ >= global_rank_size_) {
MS_LOG(ERROR) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
<< global_rank_size_;
@@ -314,13 +310,7 @@ bool CollectiveManager::AssignLocalRank() {
all_host_hashs[global_rank_id_] = host_hash;

MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
// AllGather host names across the global communication group.
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeUInt64,
host_global_group_name_)) {
if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
runtime::recovery::RecoveryErrCode::kAllGatherHostNameFailed);
}
if (!host_comm_lib_instance_->AllGatherHostHashName(host_hash, &all_host_hashs)) {
MS_LOG(ERROR) << "AllGather for host names failed.";
return false;
}


+ 162
- 0
mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_comm_lib.cc View File

@@ -34,6 +34,7 @@ bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_
global_rank_id_ = global_rank;
global_rank_size_ = global_rank_size;
initialized_ = true;
finalized_ = false;
return true;
}

@@ -50,6 +51,167 @@ bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name
return true;
}

// Gather the host hash names of all nodes through the scheduler: first push
// this node's hash name, then poll until the scheduler has collected the
// names of every node. Aborts (returns false) if the library is finalized
// while retrying.
bool MsCollectiveCommLib::AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) const {
  CHECK_IF_NULL(host_hash_names);

  // Phase 1: deliver this node's host hash name, retrying until it succeeds.
  for (;;) {
    if (SendHostHashName(host_hash_name)) {
      break;
    }
    MS_LOG(WARNING) << "Send host hash name to scheduler failed, retrying...";
    if (finalized_.load()) {
      return false;
    }
    std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
  }

  // Phase 2: poll the scheduler until it can hand back all nodes' hash names.
  for (;;) {
    if (QueryHostHashNames(host_hash_names)) {
      break;
    }
    MS_LOG(WARNING) << "Query host hash names from scheduler failed, retrying...";
    if (finalized_.load()) {
      return false;
    }
    std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
  }

  return true;
}

bool MsCollectiveCommLib::BroadcastUniqueID(const std::string &group_name, bool is_root_node, size_t root_info_size,
void *root_info) const {
CHECK_IF_NULL(root_info);
if (is_root_node) {
while (!SendUniqueID(group_name, root_info_size, root_info)) {
MS_LOG(WARNING) << "Send unique id to scheduler failed, retrying...";
if (finalized_.load()) {
return false;
}

std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
}
return true;
}

while (!QueryUniqueID(group_name, root_info_size, root_info)) {
MS_LOG(WARNING) << "Query unique id from scheduler failed, retrying...";
if (finalized_.load()) {
return false;
}

std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
}
return true;
}

// Send this node's host hash name to the scheduler and wait for its response.
// Returns false if the request could not be delivered or the scheduler reports
// failure; callers retry on failure.
bool MsCollectiveCommLib::SendHostHashName(size_t host_hash_name) const {
  CHECK_IF_NULL(node_);
  ps::core::SendHostHashNameMessage send_host_name_msg;
  send_host_name_msg.set_node_id(node_->node_id());
  send_host_name_msg.set_rank_id(node_->rank_id());
  send_host_name_msg.set_host_hash_name(host_hash_name);

  // Serialize exactly once: the previous code called SerializeAsString() twice,
  // serializing the message twice and taking data()/size() from two different
  // temporary strings.
  const std::string request = send_host_name_msg.SerializeAsString();
  std::shared_ptr<std::vector<unsigned char>> output = nullptr;
  if (!node_->SendToScheduler(request.data(), request.size(), NodeCommand::SEND_HOST_NAME, &output)) {
    MS_LOG(WARNING) << "Failed to send host hash name request to scheduler.";
    return false;
  }

  ps::core::GeneralResponseMsg resp_msg;
  CHECK_IF_NULL(output);
  (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
  if (!resp_msg.is_success()) {
    MS_LOG(WARNING) << "Send host hash name to scheduler failed.";
    return false;
  }
  return true;
}

// Query the host hash names of all nodes from the scheduler.
// Returns false while the scheduler has not yet collected every node's host
// name (this is expected; callers poll until it succeeds), or if the received
// name count does not match the caller-provided container size.
bool MsCollectiveCommLib::QueryHostHashNames(std::vector<size_t> *host_hash_names) const {
  CHECK_IF_NULL(host_hash_names);
  CHECK_IF_NULL(node_);
  ps::core::GeneralQueryMessage general_query_msg;
  general_query_msg.set_node_id(node_->node_id());
  general_query_msg.set_rank_id(node_->rank_id());

  // Serialize once; the previous code serialized twice and paired data()/size()
  // from two different temporaries.
  const std::string request = general_query_msg.SerializeAsString();
  std::shared_ptr<std::vector<unsigned char>> output = nullptr;
  if (!node_->SendToScheduler(request.data(), request.size(), NodeCommand::QUERY_HOST_NAMES, &output)) {
    MS_LOG(WARNING) << "Failed to send query host name request to scheduler.";
    return false;
  }

  ps::core::QueryHostHashNameRespMessage resp_msg;
  CHECK_IF_NULL(output);
  (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
  if (!resp_msg.is_success()) {
    // INFO level: an unsuccessful answer is normal while the scheduler is still
    // gathering host names from the other nodes.
    MS_LOG(INFO) << "Query host hash name from scheduler failed, maybe scheduler has not received all host names.";
    return false;
  }

  if (host_hash_names->size() != IntToSize(resp_msg.host_hash_names_size())) {
    MS_LOG(ERROR) << "The host_hash_names container size: " << host_hash_names->size()
                  << ", but received size: " << resp_msg.host_hash_names_size();
    return false;
  }

  for (size_t i = 0; i < host_hash_names->size(); i++) {
    (*host_hash_names)[i] = resp_msg.host_hash_names()[i];
  }

  return true;
}

// Upload a group's unique id (device root info) to the scheduler.
// Called by the group's root node only; callers retry on failure.
bool MsCollectiveCommLib::SendUniqueID(const std::string &group_name, size_t root_info_size,
                                       const void *root_info) const {
  CHECK_IF_NULL(root_info);
  CHECK_IF_NULL(node_);
  ps::core::SendUniqueIDMessage send_unique_id_msg;
  send_unique_id_msg.set_node_id(node_->node_id());
  send_unique_id_msg.set_rank_id(node_->rank_id());
  send_unique_id_msg.set_group_name(group_name);
  send_unique_id_msg.set_unique_id(root_info, root_info_size);

  // Serialize once; the previous code called SerializeAsString() twice and
  // paired data()/size() from two different temporaries.
  const std::string request = send_unique_id_msg.SerializeAsString();
  std::shared_ptr<std::vector<unsigned char>> output = nullptr;
  if (!node_->SendToScheduler(request.data(), request.size(), NodeCommand::SEND_UNIQUE_ID, &output)) {
    MS_LOG(WARNING) << "Failed to send unique id request to scheduler.";
    return false;
  }

  ps::core::GeneralResponseMsg resp_msg;
  CHECK_IF_NULL(output);
  (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
  if (!resp_msg.is_success()) {
    MS_LOG(WARNING) << "Send unique id to scheduler failed.";
    return false;
  }
  return true;
}

// Query a group's unique id from the scheduler and copy it into root_info.
// Returns false while the scheduler has not yet received the id from the root
// node (expected; callers poll), or if the bounds-checked copy fails.
bool MsCollectiveCommLib::QueryUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) const {
  CHECK_IF_NULL(root_info);
  CHECK_IF_NULL(node_);
  ps::core::QueryUniqueIDMessage query_unique_id_msg;
  query_unique_id_msg.set_node_id(node_->node_id());
  query_unique_id_msg.set_rank_id(node_->rank_id());
  query_unique_id_msg.set_group_name(group_name);

  // Serialize once; the previous code serialized twice and paired data()/size()
  // from two different temporaries.
  const std::string request = query_unique_id_msg.SerializeAsString();
  std::shared_ptr<std::vector<unsigned char>> output = nullptr;
  if (!node_->SendToScheduler(request.data(), request.size(), NodeCommand::QUERY_UNIQUE_ID, &output)) {
    MS_LOG(WARNING) << "Failed to send query unique id request to scheduler.";
    return false;
  }

  ps::core::QueryUniqueIDRespMessage resp_msg;
  CHECK_IF_NULL(output);
  (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
  if (!resp_msg.is_success()) {
    // INFO level: expected until the root node has uploaded the unique id.
    MS_LOG(INFO) << "Query unique id from scheduler failed, maybe scheduler has not received unique id.";
    return false;
  }

  // Bounds-checked copy of the received id into the caller's buffer.
  auto ret = memcpy_s(root_info, root_info_size, resp_msg.unique_id().data(), resp_msg.unique_id().length());
  if (ret != EOK) {
    MS_LOG(WARNING) << "The memcpy_s error, errorno(" << ret << ")";
    return false;
  }
  return true;
}

bool MsCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &, void *) {
CHECK_IF_NULL(send_buff);


+ 21
- 0
mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_comm_lib.h View File

@@ -32,6 +32,10 @@ constexpr char kMSGlobalGroupName[] = "ms_world_group";
using ClusterContext = mindspore::distributed::cluster::ClusterContext;
using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl;
using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;
using ps::core::NodeCommand;

// The retry interval, in seconds, for sending info to or querying info from the scheduler.
constexpr uint32_t kWaitDuration = 3;

// The collective communication library for MindSpore self developed communication framework.
class MsCollectiveCommLib : public CollectiveCommunicationLib {
@@ -45,6 +49,11 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

bool AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) const override;

bool BroadcastUniqueID(const std::string &group_name, bool is_root_node, size_t root_info_size,
void *root_info) const override;

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream = nullptr) override;

@@ -65,6 +74,18 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {
MsCollectiveCommLib();
~MsCollectiveCommLib() override = default;

// Send host hash name to scheduler.
bool SendHostHashName(size_t host_hash_name) const;

// Query host hash names of all nodes from scheduler.
bool QueryHostHashNames(std::vector<size_t> *host_hash_names) const;

// Send unique id to scheduler.
bool SendUniqueID(const std::string &group_name, size_t root_info_size, const void *root_info) const;

// Query unique id from scheduler.
bool QueryUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) const;

std::shared_ptr<ps::core::AbstractNode> node_;
};
} // namespace cpu


+ 6
- 5
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -802,8 +802,8 @@ void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &m
}
}

void AbstractNode::ProcessActorRouteServiceResp(const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size) {
void AbstractNode::ProcessReceiveSchedulerResp(const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
@@ -1300,12 +1300,13 @@ void AbstractNode::InitCommandHandler() {
handlers_[NodeCommand::SCALE_IN_DONE] = nullptr;
handlers_[NodeCommand::SEND_EVENT] = nullptr;
RegisterActorRouteTableRspHandler();
RegisterInitCollectCommResphandler();
}

void AbstractNode::RegisterActorRouteTableRspHandler() {
handlers_[NodeCommand::REGISTER_ACTOR_ROUTE] = &AbstractNode::ProcessActorRouteServiceResp;
handlers_[NodeCommand::DELETE_ACTOR_ROUTE] = &AbstractNode::ProcessActorRouteServiceResp;
handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = &AbstractNode::ProcessActorRouteServiceResp;
handlers_[NodeCommand::REGISTER_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp;
handlers_[NodeCommand::DELETE_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp;
handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp;
}

void AbstractNode::InitServerHandler() {


+ 4
- 1
mindspore/ccsrc/ps/core/abstract_node.h View File

@@ -169,7 +169,7 @@ class BACKEND_EXPORT AbstractNode : public Node {
void ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

// Process the response messages about actor route table service.
void ProcessActorRouteServiceResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
void ProcessReceiveSchedulerResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

void ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
@@ -225,6 +225,9 @@ class BACKEND_EXPORT AbstractNode : public Node {
void RegisterActorRouteTableRspHandler();
void InitServerHandler();

// Register collective communication initialization response methods.
virtual void RegisterInitCollectCommResphandler() {}

// when initializing the node, should initializing the node info.
void InitNodeInfo(const NodeRole &role);
// Initialize worker num and server num by cluster config.


+ 7
- 0
mindspore/ccsrc/ps/core/abstract_ps_node.cc View File

@@ -169,6 +169,13 @@ bool AbstractPSNode::HandleHeartbeatTimeout() {
stop_heartbeat_thread->detach();
return true;
}

void AbstractPSNode::RegisterInitCollectCommResphandler() {
handlers_[NodeCommand::SEND_HOST_NAME] = &AbstractPSNode::ProcessReceiveSchedulerResp;
handlers_[NodeCommand::QUERY_HOST_NAMES] = &AbstractPSNode::ProcessReceiveSchedulerResp;
handlers_[NodeCommand::SEND_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp;
handlers_[NodeCommand::QUERY_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp;
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 3
- 0
mindspore/ccsrc/ps/core/abstract_ps_node.h View File

@@ -35,6 +35,9 @@ class AbstractPSNode : public AbstractNode {
void StartHeartbeatTimer();

private:
// Register collective communication initialization response methods.
void RegisterInitCollectCommResphandler() override;

// Indicate whether the heartbeat thread should be stopped.
std::atomic<bool> stop_heartbeat_{false};



+ 57
- 0
mindspore/ccsrc/ps/core/protos/comm.proto View File

@@ -49,6 +49,14 @@ enum NodeCommand {
DELETE_ACTOR_ROUTE = 16;
// Lookup address of the actor.
LOOKUP_ACTOR_ROUTE = 17;
// Send host name to scheduler.
SEND_HOST_NAME = 18;
// Query all worker nodes' host name.
QUERY_HOST_NAMES = 19;
// Send unique id used to initialize collective communication.
SEND_UNIQUE_ID = 20;
// Query unique id used to initialize collective communication.
QUERY_UNIQUE_ID = 21;
}

enum NodeRole {
@@ -241,3 +249,52 @@ message ActorAddress {
string ip = 2;
uint32 port = 3;
}

message GeneralQueryMessage {
// The unique node id.
string node_id = 1;
// The rank id of the node in the cluster.
uint32 rank_id = 2;
}

message SendHostHashNameMessage {
// The unique node id.
string node_id = 1;
// The rank id of the node in the cluster.
uint32 rank_id = 2;
// The host hash name of the node.
uint64 host_hash_name = 3;
}

message QueryHostHashNameRespMessage {
bool is_success = 1;
// The host hash names of all worker nodes.
repeated uint64 host_hash_names = 2;
}

message SendUniqueIDMessage {
// The unique node id.
string node_id = 1;
// The rank id of the node in the cluster.
uint32 rank_id = 2;
// The name of the group which needs to initialize collective communication.
string group_name = 3;
// The unique id used to initialize collective communication.
bytes unique_id = 4;
}

message QueryUniqueIDMessage {
// The unique node id.
string node_id = 1;
// The rank id of the node in the cluster.
uint32 rank_id = 2;
// The name of the group which needs to initialize collective communication.
string group_name = 3;
}


message QueryUniqueIDRespMessage {
bool is_success = 1;
// The unique id used to initialize collective communication.
bytes unique_id = 2;
}

+ 140
- 0
mindspore/ccsrc/ps/core/ps_scheduler_node.cc View File

@@ -50,6 +50,146 @@ void PSSchedulerNode::RunRecovery() {

MS_LOG(INFO) << "Scheduler recovery finish.";
}

// Register the scheduler-side services that drive collective communication
// initialization: collecting workers' host hash names, answering host-name
// queries, and storing/handing out each group's unique id.
// The static_casts narrow the derived-class method pointers to the
// ResponseHandler type used by the handlers_ table.
void PSSchedulerNode::RegisterInitCollectCommServiceHandler() {
  handlers_[NodeCommand::SEND_HOST_NAME] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessSendHostName);
  handlers_[NodeCommand::QUERY_HOST_NAMES] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryHostNames);
  handlers_[NodeCommand::SEND_UNIQUE_ID] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessSendUniqueID);
  handlers_[NodeCommand::QUERY_UNIQUE_ID] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryUniqueID);
}

void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);

SendHostHashNameMessage send_host_name_msg;
send_host_name_msg.ParseFromArray(data, SizeToInt(size));
std::string node_id = send_host_name_msg.node_id();
uint32_t rank_id = send_host_name_msg.rank_id();
size_t host_hash_name = send_host_name_msg.host_hash_name();
MS_LOG(INFO) << "Received send host name request, node id: " << node_id << ", rank id: " << rank_id;

bool ret = false;
std::string error = "";
if (rank_id >= worker_num_) {
error = "The rank id: " + std::to_string(rank_id) + " should be less than: " + std::to_string(worker_num_);
MS_LOG(ERROR) << error;
} else {
host_hash_names_[rank_id] = host_hash_name;
(void)recv_rank_id_send_host_name_.insert(rank_id);
ret = true;
}

GeneralResponse(server, conn, meta, ret, error);
MS_LOG(INFO) << "Respond send host name request, node id: " << node_id << ", rank id: " << rank_id;
}

// Serve a worker's query for the host hash names of all nodes.
// The query succeeds only after every worker rank has reported its host name
// (see ProcessSendHostName); before that, the response carries
// is_success == false and the worker is expected to retry.
// NOTE(review): host_hash_names_ and the recv_rank_id_* sets are read and
// mutated here without a lock — presumably all handlers run on a single server
// thread; confirm, otherwise synchronization is needed.
void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &server,
                                            const std::shared_ptr<TcpConnection> &conn,
                                            const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_ERROR_IF_NULL_WO_RET_VAL(server);
  MS_ERROR_IF_NULL_WO_RET_VAL(conn);
  MS_ERROR_IF_NULL_WO_RET_VAL(meta);
  MS_ERROR_IF_NULL_WO_RET_VAL(data);

  GeneralQueryMessage query_msg;
  // NOTE: parse result is not checked; a malformed payload leaves default fields.
  query_msg.ParseFromArray(data, SizeToInt(size));
  std::string node_id = query_msg.node_id();
  uint32_t rank_id = query_msg.rank_id();
  MS_LOG(INFO) << "Received query host name request, node id: " << node_id << ", rank id: " << rank_id;

  // All ranks have reported once the sender set covers every slot of
  // host_hash_names_ (which is sized to worker_num_).
  bool is_success = recv_rank_id_send_host_name_.size() == host_hash_names_.size();
  QueryHostHashNameRespMessage resp_msg;
  resp_msg.set_is_success(is_success);
  if (is_success) {
    *resp_msg.mutable_host_hash_names() = {host_hash_names_.begin(), host_hash_names_.end()};
  }
  if (!server->SendMessage(conn, meta, Protos::PROTOBUF, resp_msg.SerializeAsString().data(),
                           resp_msg.ByteSizeLong())) {
    MS_LOG(ERROR) << "Scheduler failed to respond message.";
    return;
  }
  MS_LOG(INFO) << "Respond query host name request, node id: " << node_id << ", rank id: " << rank_id;

  if (is_success) {
    // Track which ranks have received the full name list; once every rank that
    // sent a name has also queried successfully, reset both sets so the next
    // initialization round starts from a clean state.
    (void)recv_rank_id_query_host_name_.insert(rank_id);

    if (recv_rank_id_query_host_name_.size() == recv_rank_id_send_host_name_.size()) {
      recv_rank_id_send_host_name_.clear();
      recv_rank_id_query_host_name_.clear();
    }
  }
}

void PSSchedulerNode::ProcessSendUniqueID(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);

SendUniqueIDMessage send_unique_id_msg;
send_unique_id_msg.ParseFromArray(data, SizeToInt(size));
std::string node_id = send_unique_id_msg.node_id();
uint32_t rank_id = send_unique_id_msg.rank_id();
std::string group_name = send_unique_id_msg.group_name();
MS_LOG(INFO) << "Received send unique id request, group name: " << group_name << ", node id: " << node_id
<< ", rank id: " << rank_id;

bool ret = false;
std::string error = "";
if (rank_id != 0) {
error = "The rank id: " + std::to_string(rank_id) + " of worker which sends unique id should be 0";
MS_LOG(ERROR) << error;
} else {
unique_id_group_[group_name] = send_unique_id_msg.unique_id();
ret = true;
}

GeneralResponse(server, conn, meta, ret, error);
MS_LOG(INFO) << "Respond send unique id request, group name: " << group_name << ", node id: " << node_id
<< ", rank id: " << rank_id;
}

// Serve a worker's query for a group's unique id. Responds with
// is_success == false (prompting the worker to retry) until the group's root
// node has uploaded the id via ProcessSendUniqueID.
void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &server,
                                           const std::shared_ptr<TcpConnection> &conn,
                                           const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_ERROR_IF_NULL_WO_RET_VAL(server);
  MS_ERROR_IF_NULL_WO_RET_VAL(conn);
  MS_ERROR_IF_NULL_WO_RET_VAL(meta);
  MS_ERROR_IF_NULL_WO_RET_VAL(data);

  QueryUniqueIDMessage query_msg;
  (void)query_msg.ParseFromArray(data, SizeToInt(size));
  std::string node_id = query_msg.node_id();
  uint32_t rank_id = query_msg.rank_id();
  std::string group_name = query_msg.group_name();
  MS_LOG(INFO) << "Received query unique id request, group name: " << group_name << ", node id: " << node_id
               << ", rank id: " << rank_id;

  auto iter = unique_id_group_.find(group_name);
  bool is_success = (iter != unique_id_group_.end());

  QueryUniqueIDRespMessage resp_msg;
  resp_msg.set_is_success(is_success);
  if (is_success) {
    resp_msg.set_unique_id(iter->second);
  }

  // Serialize once so the data pointer and the size come from the same buffer;
  // the previous code paired a temporary string's data() with ByteSizeLong().
  const std::string resp_data = resp_msg.SerializeAsString();
  if (!server->SendMessage(conn, meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
    MS_LOG(ERROR) << "Scheduler failed to respond message.";
    return;
  }

  MS_LOG(INFO) << "Respond query unique id request, group name: " << group_name << ", node id: " << node_id
               << ", rank id: " << rank_id;
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 38
- 1
mindspore/ccsrc/ps/core/ps_scheduler_node.h View File

@@ -17,6 +17,12 @@
#ifndef MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_

#include <map>
#include <memory>
#include <vector>
#include <set>
#include <string>

#include "ps/core/scheduler_node.h"
#include "ps/core/node_info.h"
#include "include/backend/visible.h"
@@ -29,7 +35,7 @@ namespace core {
// the registration request of alive nodes.
class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
public:
PSSchedulerNode() = default;
PSSchedulerNode() : worker_num_(ps::PSContext::instance()->worker_num()) { host_hash_names_.resize(worker_num_); }
~PSSchedulerNode() override = default;

protected:
@@ -40,6 +46,37 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
// Determine whether the registration request of the node should be rejected, the registration of the
// alive node should be rejected.
bool NeedRejectRegister(const NodeInfo &node_info) override { return node_info.is_alive; }

// Register collective communication initialization service.
void RegisterInitCollectCommServiceHandler() override;

// Process message for sending node's host name.
void ProcessSendHostName(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

// Process message for querying all nodes' host name.
void ProcessQueryHostNames(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

// Process message for send unique id.
void ProcessSendUniqueID(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

// Process message for querying unique id.
void ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

// Record received host hash name from workers.
std::vector<size_t> host_hash_names_;
// Record rank ids of the nodes which sent their host names.
std::set<uint32_t> recv_rank_id_send_host_name_;
// Record rank ids of the nodes which queried the host names.
std::set<uint32_t> recv_rank_id_query_host_name_;

// Record unique id of every group, key: group name, value: unique id.
std::map<std::string, std::string> unique_id_group_;

uint32_t worker_num_;
};
} // namespace core
} // namespace ps


+ 1
- 0
mindspore/ccsrc/ps/core/scheduler_node.cc View File

@@ -233,6 +233,7 @@ void SchedulerNode::InitCommandHandler() {
handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone;
handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent;
RegisterActorRouteTableServiceHandler();
RegisterInitCollectCommServiceHandler();
}

void SchedulerNode::RegisterActorRouteTableServiceHandler() {


+ 3
- 0
mindspore/ccsrc/ps/core/scheduler_node.h View File

@@ -89,6 +89,9 @@ class BACKEND_EXPORT SchedulerNode : public Node {
void RegisterActorRouteTableServiceHandler();
void InitializeActorRouteTableService();

// Register collective communication initialization service.
virtual void RegisterInitCollectCommServiceHandler() {}

const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info);

void ProcessHeartbeat(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,


+ 2
- 1
mindspore/ccsrc/runtime/collective/collective_communication_lib.cc View File

@@ -19,7 +19,7 @@
namespace mindspore {
namespace device {
bool CollectiveCommunicationLib::Finalize() {
if (!initialized_) {
if (!initialized_ || finalized_.load()) {
return true;
}

@@ -31,6 +31,7 @@ bool CollectiveCommunicationLib::Finalize() {
}
groups_.clear();
initialized_ = false;
finalized_ = true;
return true;
}



+ 15
- 1
mindspore/ccsrc/runtime/collective/collective_communication_lib.h View File

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_COMMUNICATION_LIB_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_COMMUNICATION_LIB_H_

#include <atomic>
#include <map>
#include <memory>
#include <vector>
@@ -44,7 +45,8 @@ enum CollectiveOpReduceType : int64_t {
// MsCollectiveCommLib which uses the host-side communication library developed by MindSpore.
class CollectiveCommunicationLib {
public:
CollectiveCommunicationLib() : initialized_(false), global_rank_id_(0), local_rank_id_(0), global_rank_size_(0) {}
CollectiveCommunicationLib()
: initialized_(false), finalized_(false), global_rank_id_(0), local_rank_id_(0), global_rank_size_(0) {}
virtual ~CollectiveCommunicationLib() { groups_.clear(); }

// Initialize collecitve communication library.
@@ -77,6 +79,15 @@ class CollectiveCommunicationLib {
// Return communication group pointer.
virtual CommunicationGroupPtr GetGroup(const std::string &group_name);

// AllGather host names of all nodes, used to initialize collective communication.
virtual bool AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) const { return true; }

// Broadcast the device root information to all nodes on host side, used to initialize collective communication.
virtual bool BroadcastUniqueID(const std::string &group_name, bool is_root_node, size_t root_info_size,
void *root_info) const {
return true;
}

// Primitive of collective operations.
virtual bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream = nullptr) {
@@ -122,6 +133,9 @@ class CollectiveCommunicationLib {
// Whether this collective communication library is initialized.
bool initialized_;

// Whether this collective communication library is finalized.
std::atomic_bool finalized_;

// The global group name.
std::string global_group_name_;



Loading…
Cancel
Save