Merge pull request !31811 from zyli2020/masterr1.7
| @@ -120,18 +120,14 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name, | |||
| // Step 3: Generate device information of the root node. | |||
| CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name); | |||
| MS_EXCEPTION_IF_NULL(group); | |||
| bool is_root_node = (group->GetGroupRank(global_rank_id_) == 0); | |||
| size_t root_info_size = 0; | |||
| void *root_info = group->GenerateRootInfo(&root_info_size); | |||
| MS_EXCEPTION_IF_NULL(root_info); | |||
| // Step 4: Broadcast the device root information to all nodes on host side. | |||
| if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt8, 0, | |||
| group_name)) { | |||
| if (!host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info)) { | |||
| MS_LOG(ERROR) << "Broadcast for device root info failed on the host side."; | |||
| if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) { | |||
| runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status( | |||
| runtime::recovery::RecoveryErrCode::kBroadcastUniqueIDFailed); | |||
| } | |||
| return false; | |||
| } | |||
| @@ -305,7 +301,7 @@ bool CollectiveManager::AssignLocalRank() { | |||
| // that local rank id won't repeat. | |||
| size_t host_hash = std::hash<std::string>()(host_name); | |||
| const uint32_t kGlobalRankSize = global_rank_size_; | |||
| size_t all_host_hashs[kGlobalRankSize]; | |||
| std::vector<size_t> all_host_hashs(kGlobalRankSize); | |||
| if (global_rank_id_ >= global_rank_size_) { | |||
| MS_LOG(ERROR) << "The global rank id " << global_rank_id_ << " should be less than global rank size " | |||
| << global_rank_size_; | |||
| @@ -314,13 +310,7 @@ bool CollectiveManager::AssignLocalRank() { | |||
| all_host_hashs[global_rank_id_] = host_hash; | |||
| MS_EXCEPTION_IF_NULL(host_comm_lib_instance_); | |||
| // AllGather host names across the global communication group. | |||
| if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeUInt64, | |||
| host_global_group_name_)) { | |||
| if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) { | |||
| runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status( | |||
| runtime::recovery::RecoveryErrCode::kAllGatherHostNameFailed); | |||
| } | |||
| if (!host_comm_lib_instance_->AllGatherHostHashName(host_hash, &all_host_hashs)) { | |||
| MS_LOG(ERROR) << "AllGather for host names failed."; | |||
| return false; | |||
| } | |||
| @@ -34,6 +34,7 @@ bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_ | |||
| global_rank_id_ = global_rank; | |||
| global_rank_size_ = global_rank_size; | |||
| initialized_ = true; | |||
| finalized_ = false; | |||
| return true; | |||
| } | |||
| @@ -50,6 +51,167 @@ bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name | |||
| return true; | |||
| } | |||
| bool MsCollectiveCommLib::AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) const { | |||
| CHECK_IF_NULL(host_hash_names); | |||
| while (!SendHostHashName(host_hash_name)) { | |||
| MS_LOG(WARNING) << "Send host hash name to scheduler failed, retrying..."; | |||
| if (finalized_.load()) { | |||
| return false; | |||
| } | |||
| std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration)); | |||
| } | |||
| while (!QueryHostHashNames(host_hash_names)) { | |||
| MS_LOG(WARNING) << "Query host hash names from scheduler failed, retrying..."; | |||
| if (finalized_.load()) { | |||
| return false; | |||
| } | |||
| std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration)); | |||
| } | |||
| return true; | |||
| } | |||
| bool MsCollectiveCommLib::BroadcastUniqueID(const std::string &group_name, bool is_root_node, size_t root_info_size, | |||
| void *root_info) const { | |||
| CHECK_IF_NULL(root_info); | |||
| if (is_root_node) { | |||
| while (!SendUniqueID(group_name, root_info_size, root_info)) { | |||
| MS_LOG(WARNING) << "Send unique id to scheduler failed, retrying..."; | |||
| if (finalized_.load()) { | |||
| return false; | |||
| } | |||
| std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration)); | |||
| } | |||
| return true; | |||
| } | |||
| while (!QueryUniqueID(group_name, root_info_size, root_info)) { | |||
| MS_LOG(WARNING) << "Query unique id from scheduler failed, retrying..."; | |||
| if (finalized_.load()) { | |||
| return false; | |||
| } | |||
| std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration)); | |||
| } | |||
| return true; | |||
| } | |||
| bool MsCollectiveCommLib::SendHostHashName(size_t host_hash_name) const { | |||
| CHECK_IF_NULL(node_); | |||
| ps::core::SendHostHashNameMessage send_host_name_msg; | |||
| send_host_name_msg.set_node_id(node_->node_id()); | |||
| send_host_name_msg.set_rank_id(node_->rank_id()); | |||
| send_host_name_msg.set_host_hash_name(host_hash_name); | |||
| std::shared_ptr<std::vector<unsigned char>> output = nullptr; | |||
| if (!node_->SendToScheduler(send_host_name_msg.SerializeAsString().data(), | |||
| send_host_name_msg.SerializeAsString().size(), NodeCommand::SEND_HOST_NAME, &output)) { | |||
| MS_LOG(WARNING) << "Failed to send host hash name request to scheduler."; | |||
| return false; | |||
| } | |||
| ps::core::GeneralResponseMsg resp_msg; | |||
| CHECK_IF_NULL(output); | |||
| (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size())); | |||
| if (!resp_msg.is_success()) { | |||
| MS_LOG(WARNING) << "Send host hash name to scheduler failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool MsCollectiveCommLib::QueryHostHashNames(std::vector<size_t> *host_hash_names) const { | |||
| CHECK_IF_NULL(host_hash_names); | |||
| CHECK_IF_NULL(node_); | |||
| ps::core::GeneralQueryMessage general_query_msg; | |||
| general_query_msg.set_node_id(node_->node_id()); | |||
| general_query_msg.set_rank_id(node_->rank_id()); | |||
| std::shared_ptr<std::vector<unsigned char>> output = nullptr; | |||
| if (!node_->SendToScheduler(general_query_msg.SerializeAsString().data(), | |||
| general_query_msg.SerializeAsString().size(), NodeCommand::QUERY_HOST_NAMES, &output)) { | |||
| MS_LOG(WARNING) << "Failed to send query host name request to scheduler."; | |||
| return false; | |||
| } | |||
| ps::core::QueryHostHashNameRespMessage resp_msg; | |||
| CHECK_IF_NULL(output); | |||
| (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size())); | |||
| if (!resp_msg.is_success()) { | |||
| MS_LOG(INFO) << "Query host hash name from scheduer failed, maybe scheduler has not received all host names."; | |||
| return false; | |||
| } | |||
| if (host_hash_names->size() != IntToSize(resp_msg.host_hash_names_size())) { | |||
| MS_LOG(ERROR) << "The host_hash_names container size: " << host_hash_names->size() | |||
| << ", but received size: " << resp_msg.host_hash_names_size(); | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < host_hash_names->size(); i++) { | |||
| (*host_hash_names)[i] = resp_msg.host_hash_names()[i]; | |||
| } | |||
| return true; | |||
| } | |||
| bool MsCollectiveCommLib::SendUniqueID(const std::string &group_name, size_t root_info_size, | |||
| const void *root_info) const { | |||
| CHECK_IF_NULL(root_info); | |||
| CHECK_IF_NULL(node_); | |||
| ps::core::SendUniqueIDMessage send_unique_id_msg; | |||
| send_unique_id_msg.set_node_id(node_->node_id()); | |||
| send_unique_id_msg.set_rank_id(node_->rank_id()); | |||
| send_unique_id_msg.set_group_name(group_name); | |||
| send_unique_id_msg.set_unique_id(root_info, root_info_size); | |||
| std::shared_ptr<std::vector<unsigned char>> output = nullptr; | |||
| if (!node_->SendToScheduler(send_unique_id_msg.SerializeAsString().data(), | |||
| send_unique_id_msg.SerializeAsString().size(), NodeCommand::SEND_UNIQUE_ID, &output)) { | |||
| MS_LOG(WARNING) << "Failed to send unique id request to scheduler."; | |||
| return false; | |||
| } | |||
| ps::core::GeneralResponseMsg resp_msg; | |||
| CHECK_IF_NULL(output); | |||
| (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size())); | |||
| if (!resp_msg.is_success()) { | |||
| MS_LOG(WARNING) << "Send unique id to scheduler failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool MsCollectiveCommLib::QueryUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) const { | |||
| CHECK_IF_NULL(root_info); | |||
| CHECK_IF_NULL(node_); | |||
| ps::core::QueryUniqueIDMessage query_unique_id_msg; | |||
| query_unique_id_msg.set_node_id(node_->node_id()); | |||
| query_unique_id_msg.set_rank_id(node_->rank_id()); | |||
| query_unique_id_msg.set_group_name(group_name); | |||
| std::shared_ptr<std::vector<unsigned char>> output = nullptr; | |||
| if (!node_->SendToScheduler(query_unique_id_msg.SerializeAsString().data(), | |||
| query_unique_id_msg.SerializeAsString().size(), NodeCommand::QUERY_UNIQUE_ID, &output)) { | |||
| MS_LOG(WARNING) << "Failed to send query unique id request to scheduler."; | |||
| return false; | |||
| } | |||
| ps::core::QueryUniqueIDRespMessage resp_msg; | |||
| CHECK_IF_NULL(output); | |||
| (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size())); | |||
| if (!resp_msg.is_success()) { | |||
| MS_LOG(INFO) << "Query unique id from scheduer failed, maybe scheduler has not received unique id."; | |||
| return false; | |||
| } | |||
| auto ret = memcpy_s(root_info, root_info_size, resp_msg.unique_id().data(), resp_msg.unique_id().length()); | |||
| if (ret != EOK) { | |||
| MS_LOG(WARNING) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool MsCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, | |||
| const std::string &, void *) { | |||
| CHECK_IF_NULL(send_buff); | |||
| @@ -32,6 +32,10 @@ constexpr char kMSGlobalGroupName[] = "ms_world_group"; | |||
| using ClusterContext = mindspore::distributed::cluster::ClusterContext; | |||
| using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl; | |||
| using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo; | |||
| using ps::core::NodeCommand; | |||
| // The time interval for send info or query info between worker and scheduler. | |||
| constexpr uint32_t kWaitDuration = 3; | |||
| // The collective communication library for MindSpore self developed communication framework. | |||
| class MsCollectiveCommLib : public CollectiveCommunicationLib { | |||
| @@ -45,6 +49,11 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib { | |||
| bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override; | |||
| bool AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) const override; | |||
| bool BroadcastUniqueID(const std::string &group_name, bool is_root_node, size_t root_info_size, | |||
| void *root_info) const override; | |||
| bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, | |||
| const std::string &group_name, void *stream = nullptr) override; | |||
| @@ -65,6 +74,18 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib { | |||
| MsCollectiveCommLib(); | |||
| ~MsCollectiveCommLib() override = default; | |||
| // Send host hash name to scheduler. | |||
| bool SendHostHashName(size_t host_hash_name) const; | |||
| // Query host hash names of all nodes from scheduler. | |||
| bool QueryHostHashNames(std::vector<size_t> *host_hash_names) const; | |||
| // Send unique id to scheduler. | |||
| bool SendUniqueID(const std::string &group_name, size_t root_info_size, const void *root_info) const; | |||
| // Query unique id from scheduler. | |||
| bool QueryUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) const; | |||
| std::shared_ptr<ps::core::AbstractNode> node_; | |||
| }; | |||
| } // namespace cpu | |||
| @@ -802,8 +802,8 @@ void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &m | |||
| } | |||
| } | |||
| void AbstractNode::ProcessActorRouteServiceResp(const std::shared_ptr<MessageMeta> &meta, const void *data, | |||
| size_t size) { | |||
| void AbstractNode::ProcessReceiveSchedulerResp(const std::shared_ptr<MessageMeta> &meta, const void *data, | |||
| size_t size) { | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| std::lock_guard<std::mutex> lock(receive_messages_mutex_); | |||
| @@ -1300,12 +1300,13 @@ void AbstractNode::InitCommandHandler() { | |||
| handlers_[NodeCommand::SCALE_IN_DONE] = nullptr; | |||
| handlers_[NodeCommand::SEND_EVENT] = nullptr; | |||
| RegisterActorRouteTableRspHandler(); | |||
| RegisterInitCollectCommResphandler(); | |||
| } | |||
| void AbstractNode::RegisterActorRouteTableRspHandler() { | |||
| handlers_[NodeCommand::REGISTER_ACTOR_ROUTE] = &AbstractNode::ProcessActorRouteServiceResp; | |||
| handlers_[NodeCommand::DELETE_ACTOR_ROUTE] = &AbstractNode::ProcessActorRouteServiceResp; | |||
| handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = &AbstractNode::ProcessActorRouteServiceResp; | |||
| handlers_[NodeCommand::REGISTER_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp; | |||
| handlers_[NodeCommand::DELETE_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp; | |||
| handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp; | |||
| } | |||
| void AbstractNode::InitServerHandler() { | |||
| @@ -169,7 +169,7 @@ class BACKEND_EXPORT AbstractNode : public Node { | |||
| void ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| // Process the response messages about actor route table service. | |||
| void ProcessActorRouteServiceResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| void ProcessReceiveSchedulerResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| void ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, | |||
| const Protos &protos, const void *data, size_t size); | |||
| @@ -225,6 +225,9 @@ class BACKEND_EXPORT AbstractNode : public Node { | |||
| void RegisterActorRouteTableRspHandler(); | |||
| void InitServerHandler(); | |||
| // Register collective communication initialization response methods. | |||
| virtual void RegisterInitCollectCommResphandler() {} | |||
| // when initializing the node, should initializing the node info. | |||
| void InitNodeInfo(const NodeRole &role); | |||
| // Initialize worker num and server num by cluster config. | |||
| @@ -169,6 +169,13 @@ bool AbstractPSNode::HandleHeartbeatTimeout() { | |||
| stop_heartbeat_thread->detach(); | |||
| return true; | |||
| } | |||
| void AbstractPSNode::RegisterInitCollectCommResphandler() { | |||
| handlers_[NodeCommand::SEND_HOST_NAME] = &AbstractPSNode::ProcessReceiveSchedulerResp; | |||
| handlers_[NodeCommand::QUERY_HOST_NAMES] = &AbstractPSNode::ProcessReceiveSchedulerResp; | |||
| handlers_[NodeCommand::SEND_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp; | |||
| handlers_[NodeCommand::QUERY_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -35,6 +35,9 @@ class AbstractPSNode : public AbstractNode { | |||
| void StartHeartbeatTimer(); | |||
| private: | |||
| // Register collective communication initialization response methods. | |||
| void RegisterInitCollectCommResphandler() override; | |||
| // Indicate whether the heartbeat thread should be stopped. | |||
| std::atomic<bool> stop_heartbeat_{false}; | |||
| @@ -49,6 +49,14 @@ enum NodeCommand { | |||
| DELETE_ACTOR_ROUTE = 16; | |||
| // Lookup address of the actor. | |||
| LOOKUP_ACTOR_ROUTE = 17; | |||
| // Send host name to scheduler. | |||
| SEND_HOST_NAME = 18; | |||
| // Query all worker nodes' host name. | |||
| QUERY_HOST_NAMES = 19; | |||
| // Send unique id used to initialize collective communication. | |||
| SEND_UNIQUE_ID = 20; | |||
| // Query unique id used to initialize collective communication. | |||
| QUERY_UNIQUE_ID = 21; | |||
| } | |||
| enum NodeRole { | |||
| @@ -241,3 +249,52 @@ message ActorAddress { | |||
| string ip = 2; | |||
| uint32 port = 3; | |||
| } | |||
| message GeneralQueryMessage { | |||
| // The unique node id. | |||
| string node_id = 1; | |||
| // The rank id of the node in the cluster. | |||
| uint32 rank_id = 2; | |||
| } | |||
| message SendHostHashNameMessage { | |||
| // The unique node id. | |||
| string node_id = 1; | |||
| // The rank id of the node in the cluster. | |||
| uint32 rank_id = 2; | |||
| // The host hash name of the node. | |||
| uint64 host_hash_name = 3; | |||
| } | |||
| message QueryHostHashNameRespMessage { | |||
| bool is_success = 1; | |||
| // The host hash names of all worker nodes. | |||
| repeated uint64 host_hash_names = 2; | |||
| } | |||
| message SendUniqueIDMessage { | |||
| // The unique node id. | |||
| string node_id = 1; | |||
| // The rank id of the node in the cluster. | |||
| uint32 rank_id = 2; | |||
| // The group name of goupt which need to initialize collective communication. | |||
| string group_name = 3; | |||
| // The unique id used to initialize collective communication. | |||
| bytes unique_id = 4; | |||
| } | |||
| message QueryUniqueIDMessage { | |||
| // The unique node id. | |||
| string node_id = 1; | |||
| // The rank id of the node in the cluster. | |||
| uint32 rank_id = 2; | |||
| // The group name of goupt which need to initialize collective communication. | |||
| string group_name = 3; | |||
| } | |||
| message QueryUniqueIDRespMessage { | |||
| bool is_success = 1; | |||
| // The unique id used to initialize collective communication. | |||
| bytes unique_id = 2; | |||
| } | |||
| @@ -50,6 +50,146 @@ void PSSchedulerNode::RunRecovery() { | |||
| MS_LOG(INFO) << "Scheduler recovery finish."; | |||
| } | |||
| void PSSchedulerNode::RegisterInitCollectCommServiceHandler() { | |||
| handlers_[NodeCommand::SEND_HOST_NAME] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessSendHostName); | |||
| handlers_[NodeCommand::QUERY_HOST_NAMES] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryHostNames); | |||
| handlers_[NodeCommand::SEND_UNIQUE_ID] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessSendUniqueID); | |||
| handlers_[NodeCommand::QUERY_UNIQUE_ID] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryUniqueID); | |||
| } | |||
| void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &server, | |||
| const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(conn); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(meta); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data); | |||
| SendHostHashNameMessage send_host_name_msg; | |||
| send_host_name_msg.ParseFromArray(data, SizeToInt(size)); | |||
| std::string node_id = send_host_name_msg.node_id(); | |||
| uint32_t rank_id = send_host_name_msg.rank_id(); | |||
| size_t host_hash_name = send_host_name_msg.host_hash_name(); | |||
| MS_LOG(INFO) << "Received send host name request, node id: " << node_id << ", rank id: " << rank_id; | |||
| bool ret = false; | |||
| std::string error = ""; | |||
| if (rank_id >= worker_num_) { | |||
| error = "The rank id: " + std::to_string(rank_id) + " should be less than: " + std::to_string(worker_num_); | |||
| MS_LOG(ERROR) << error; | |||
| } else { | |||
| host_hash_names_[rank_id] = host_hash_name; | |||
| (void)recv_rank_id_send_host_name_.insert(rank_id); | |||
| ret = true; | |||
| } | |||
| GeneralResponse(server, conn, meta, ret, error); | |||
| MS_LOG(INFO) << "Respond send host name request, node id: " << node_id << ", rank id: " << rank_id; | |||
| } | |||
| void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &server, | |||
| const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(conn); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(meta); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data); | |||
| GeneralQueryMessage query_msg; | |||
| query_msg.ParseFromArray(data, SizeToInt(size)); | |||
| std::string node_id = query_msg.node_id(); | |||
| uint32_t rank_id = query_msg.rank_id(); | |||
| MS_LOG(INFO) << "Received query host name request, node id: " << node_id << ", rank id: " << rank_id; | |||
| bool is_success = recv_rank_id_send_host_name_.size() == host_hash_names_.size(); | |||
| QueryHostHashNameRespMessage resp_msg; | |||
| resp_msg.set_is_success(is_success); | |||
| if (is_success) { | |||
| *resp_msg.mutable_host_hash_names() = {host_hash_names_.begin(), host_hash_names_.end()}; | |||
| } | |||
| if (!server->SendMessage(conn, meta, Protos::PROTOBUF, resp_msg.SerializeAsString().data(), | |||
| resp_msg.ByteSizeLong())) { | |||
| MS_LOG(ERROR) << "Scheduler failed to respond message."; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Respond query host name request, node id: " << node_id << ", rank id: " << rank_id; | |||
| if (is_success) { | |||
| (void)recv_rank_id_query_host_name_.insert(rank_id); | |||
| if (recv_rank_id_query_host_name_.size() == recv_rank_id_send_host_name_.size()) { | |||
| recv_rank_id_send_host_name_.clear(); | |||
| recv_rank_id_query_host_name_.clear(); | |||
| } | |||
| } | |||
| } | |||
| void PSSchedulerNode::ProcessSendUniqueID(const std::shared_ptr<TcpServer> &server, | |||
| const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(conn); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(meta); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data); | |||
| SendUniqueIDMessage send_unique_id_msg; | |||
| send_unique_id_msg.ParseFromArray(data, SizeToInt(size)); | |||
| std::string node_id = send_unique_id_msg.node_id(); | |||
| uint32_t rank_id = send_unique_id_msg.rank_id(); | |||
| std::string group_name = send_unique_id_msg.group_name(); | |||
| MS_LOG(INFO) << "Received send unique id request, group name: " << group_name << ", node id: " << node_id | |||
| << ", rank id: " << rank_id; | |||
| bool ret = false; | |||
| std::string error = ""; | |||
| if (rank_id != 0) { | |||
| error = "The rank id: " + std::to_string(rank_id) + " of worker which sends unique id should be 0"; | |||
| MS_LOG(ERROR) << error; | |||
| } else { | |||
| unique_id_group_[group_name] = send_unique_id_msg.unique_id(); | |||
| ret = true; | |||
| } | |||
| GeneralResponse(server, conn, meta, ret, error); | |||
| MS_LOG(INFO) << "Respond send unique id request, group name: " << group_name << ", node id: " << node_id | |||
| << ", rank id: " << rank_id; | |||
| } | |||
| void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &server, | |||
| const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(conn); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(meta); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data); | |||
| QueryUniqueIDMessage query_msg; | |||
| query_msg.ParseFromArray(data, SizeToInt(size)); | |||
| std::string node_id = query_msg.node_id(); | |||
| uint32_t rank_id = query_msg.rank_id(); | |||
| std::string group_name = query_msg.group_name(); | |||
| MS_LOG(INFO) << "Received query unique id request, group name: " << group_name << ", node id: " << node_id | |||
| << ", rank id: " << rank_id; | |||
| auto iter = unique_id_group_.find(group_name); | |||
| bool is_success = (iter != unique_id_group_.end()); | |||
| QueryUniqueIDRespMessage resp_msg; | |||
| resp_msg.set_is_success(is_success); | |||
| if (is_success) { | |||
| resp_msg.set_unique_id(iter->second); | |||
| } | |||
| if (!server->SendMessage(conn, meta, Protos::PROTOBUF, resp_msg.SerializeAsString().data(), | |||
| resp_msg.ByteSizeLong())) { | |||
| MS_LOG(ERROR) << "Scheduler failed to respond message."; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Respond query unique id request, group name: " << group_name << ", node id: " << node_id | |||
| << ", rank id: " << rank_id; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -17,6 +17,12 @@ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <set> | |||
| #include <string> | |||
| #include "ps/core/scheduler_node.h" | |||
| #include "ps/core/node_info.h" | |||
| #include "include/backend/visible.h" | |||
| @@ -29,7 +35,7 @@ namespace core { | |||
| // the registration request of alive nodes. | |||
| class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode { | |||
| public: | |||
| PSSchedulerNode() = default; | |||
| PSSchedulerNode() : worker_num_(ps::PSContext::instance()->worker_num()) { host_hash_names_.resize(worker_num_); } | |||
| ~PSSchedulerNode() override = default; | |||
| protected: | |||
| @@ -40,6 +46,37 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode { | |||
| // Determine whether the registration request of the node should be rejected, the registration of the | |||
| // alive node should be rejected. | |||
| bool NeedRejectRegister(const NodeInfo &node_info) override { return node_info.is_alive; } | |||
| // Register collective communication initialization service. | |||
| void RegisterInitCollectCommServiceHandler() override; | |||
| // Process message for sending node's host name. | |||
| void ProcessSendHostName(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| // Process message for querying all nodes' host name. | |||
| void ProcessQueryHostNames(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| // Process message for send unique id. | |||
| void ProcessSendUniqueID(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| // Process message for querying unique id. | |||
| void ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| // Record received host hash name from workers. | |||
| std::vector<size_t> host_hash_names_; | |||
| // Record rank id of the nodes which sended host name. | |||
| std::set<uint32_t> recv_rank_id_send_host_name_; | |||
| // Record rank id of the nodes which queried host name. | |||
| std::set<uint32_t> recv_rank_id_query_host_name_; | |||
| // Record unique id of every group, key: group name, value: unique id. | |||
| std::map<std::string, std::string> unique_id_group_; | |||
| uint32_t worker_num_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -233,6 +233,7 @@ void SchedulerNode::InitCommandHandler() { | |||
| handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone; | |||
| handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent; | |||
| RegisterActorRouteTableServiceHandler(); | |||
| RegisterInitCollectCommServiceHandler(); | |||
| } | |||
| void SchedulerNode::RegisterActorRouteTableServiceHandler() { | |||
| @@ -89,6 +89,9 @@ class BACKEND_EXPORT SchedulerNode : public Node { | |||
| void RegisterActorRouteTableServiceHandler(); | |||
| void InitializeActorRouteTableService(); | |||
| // Register collective communication initialization service. | |||
| virtual void RegisterInitCollectCommServiceHandler() {} | |||
| const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info); | |||
| void ProcessHeartbeat(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| @@ -19,7 +19,7 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| bool CollectiveCommunicationLib::Finalize() { | |||
| if (!initialized_) { | |||
| if (!initialized_ || finalized_.load()) { | |||
| return true; | |||
| } | |||
| @@ -31,6 +31,7 @@ bool CollectiveCommunicationLib::Finalize() { | |||
| } | |||
| groups_.clear(); | |||
| initialized_ = false; | |||
| finalized_ = true; | |||
| return true; | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_COMMUNICATION_LIB_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_COMMUNICATION_LIB_H_ | |||
| #include <atomic> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <vector> | |||
| @@ -44,7 +45,8 @@ enum CollectiveOpReduceType : int64_t { | |||
| // MsCollectiveCommLib which uses the host-side communication library developed by MindSpore. | |||
| class CollectiveCommunicationLib { | |||
| public: | |||
| CollectiveCommunicationLib() : initialized_(false), global_rank_id_(0), local_rank_id_(0), global_rank_size_(0) {} | |||
| CollectiveCommunicationLib() | |||
| : initialized_(false), finalized_(false), global_rank_id_(0), local_rank_id_(0), global_rank_size_(0) {} | |||
| virtual ~CollectiveCommunicationLib() { groups_.clear(); } | |||
| // Initialize collecitve communication library. | |||
| @@ -77,6 +79,15 @@ class CollectiveCommunicationLib { | |||
| // Return communication group pointer. | |||
| virtual CommunicationGroupPtr GetGroup(const std::string &group_name); | |||
| // AllGather host names of all nodes, used to initialize collective communication. | |||
| virtual bool AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) const { return true; } | |||
| // Broadcast the device root information to all nodes on host side, used to initialize collective communication. | |||
| virtual bool BroadcastUniqueID(const std::string &group_name, bool is_root_node, size_t root_info_size, | |||
| void *root_info) const { | |||
| return true; | |||
| } | |||
| // Primitive of collective operations. | |||
| virtual bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, | |||
| const std::string &group_name, void *stream = nullptr) { | |||
| @@ -122,6 +133,9 @@ class CollectiveCommunicationLib { | |||
| // Whether this collective communication library is initialized. | |||
| bool initialized_; | |||
| // Whether this collective communication library is finalized. | |||
| std::atomic_bool finalized_; | |||
| // The global group name. | |||
| std::string global_group_name_; | |||