From: @anancds Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -12,7 +12,9 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_message_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc") | |||
| endif() | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | |||
| endif () | |||
| set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | |||
| add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) | |||
| @@ -109,6 +109,19 @@ std::string CommUtil::GenerateUUID() { | |||
| return ss.str(); | |||
| } | |||
| std::string CommUtil::NodeRoleToString(const NodeRole &role) { | |||
| switch (role) { | |||
| case NodeRole::SCHEDULER: | |||
| return "SCHEDULER"; | |||
| case NodeRole::SERVER: | |||
| return "SERVER"; | |||
| case NodeRole::WORKER: | |||
| return "WORKER"; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!"; | |||
| } | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -41,11 +41,13 @@ | |||
| #include <cstdlib> | |||
| #include <cstring> | |||
| #include <functional> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <random> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| @@ -63,7 +65,9 @@ class CommUtil { | |||
| static bool CheckIp(const std::string &ip); | |||
| static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); | |||
| static std::string GenerateUUID(); | |||
| static std::string NodeRoleToString(const NodeRole &role); | |||
| private: | |||
| static std::random_device rd; | |||
| static std::mt19937_64 gen; | |||
| static std::uniform_int_distribution<> dis; | |||
| @@ -0,0 +1,158 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void Node::Heartbeat(const std::shared_ptr<TcpClient> &client) { | |||
| MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | |||
| << " begin send heartbeat to the scheduler!"; | |||
| heart_beat_thread_ = std::make_unique<std::thread>([&]() { | |||
| while (!is_finish_.load()) { | |||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::HEARTBEAT); | |||
| HeartbeatMessage heartbeat_message; | |||
| heartbeat_message.set_node_id(node_info_.node_id_); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(heartbeat_message.SerializeAsString()); | |||
| SendMessageAsync(client, message); | |||
| } | |||
| }); | |||
| heart_beat_thread_->detach(); | |||
| } | |||
| void Node::ProcessHeartbeatResp(const CommMessage &message) { | |||
| HeartbeatRespMessage heartbeat_resp_message; | |||
| heartbeat_resp_message.ParseFromString(message.data()); | |||
| is_ready_ = heartbeat_resp_message.is_cluster_ready(); | |||
| if (is_ready_.load()) { | |||
| wait_start_cond_.notify_all(); | |||
| } | |||
| is_finish_ = heartbeat_resp_message.is_cluster_finish(); | |||
| if (is_finish_.load()) { | |||
| wait_finish_cond_.notify_all(); | |||
| } | |||
| is_timeout_ = heartbeat_resp_message.is_cluster_timeout(); | |||
| if (is_timeout_ && on_node_event_message_) { | |||
| on_node_event_message_(NodeEvent::NODE_TIMEOUT); | |||
| } | |||
| } | |||
| void Node::FetchServers(const std::shared_ptr<TcpClient> &client) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::FETCH_SERVER); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| SendMessageSync(client, message); | |||
| } | |||
| void Node::ProcessFetchServersResp(const CommMessage &message) { | |||
| FetchServersRespMessage fetch_servers_resp_message; | |||
| fetch_servers_resp_message.ParseFromString(message.data()); | |||
| for (const auto &it : fetch_servers_resp_message.servers_meta()) { | |||
| server_rank_ids_[it.rank_id()] = std::make_pair(it.ip(), it.port()); | |||
| } | |||
| MS_LOG(DEBUG) << "The all server host size is:" << server_rank_ids_.size(); | |||
| } | |||
| std::string Node::node_id() const { return node_info_.node_id_; } | |||
| uint32_t Node::rank_id() const { return node_info_.rank_id_; } | |||
| void Node::set_callback(const OnNodeEventMessage &on_node_event_message) { | |||
| on_node_event_message_ = on_node_event_message; | |||
| } | |||
| void Node::Wait(uint64_t request_id) { | |||
| std::unique_lock<std::mutex> lock(message_mutex_); | |||
| message_tracker_cond_.wait(lock, [&] { | |||
| bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second; | |||
| if (ret) { | |||
| MS_LOG(DEBUG) << "Message tracker remove request id:" << request_id; | |||
| message_tracker_.erase(request_id); | |||
| } | |||
| return ret; | |||
| }); | |||
| } | |||
| void Node::Disconnect(const std::shared_ptr<TcpClient> &client) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::FINISH); | |||
| FinishMessage finish_message; | |||
| finish_message.set_node_id(node_info_.node_id_); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(finish_message.SerializeAsString()); | |||
| SendMessageSync(client, message); | |||
| WaitForDisconnect(); | |||
| } | |||
| void Node::WaitForStart() { | |||
| std::unique_lock<std::mutex> lock(wait_start_mutex_); | |||
| wait_start_cond_.wait(lock, [&] { | |||
| if (is_ready_.load()) { | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success start!"; | |||
| } | |||
| return is_ready_.load(); | |||
| }); | |||
| } | |||
| void Node::WaitForDisconnect() { | |||
| std::unique_lock<std::mutex> lock(wait_finish_mutex_); | |||
| wait_finish_cond_.wait(lock, [&] { | |||
| if (is_finish_.load()) { | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success finish!"; | |||
| } | |||
| return is_finish_.load(); | |||
| }); | |||
| } | |||
| void Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| Wait(request_id); | |||
| } | |||
| void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| } | |||
| void Node::NotifyMessageArrival(const CommMessage &message) { | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| uint64_t request_id = message_meta.request_id(); | |||
| message_tracker_[request_id].second++; | |||
| message_tracker_cond_.notify_all(); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,102 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_NODE_H_ | |||
| #include <atomic> | |||
| #include <cstdlib> | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <condition_variable> | |||
| #include <utility> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "ps/core/node_info.h" | |||
| #include "ps/core/tcp_client.h" | |||
| #include "ps/core/tcp_server.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class Node { | |||
| public: | |||
| Node() | |||
| : is_ready_(false), | |||
| is_finish_(false), | |||
| is_timeout_(false), | |||
| is_already_stopped_(true), | |||
| next_request_id_(0), | |||
| heart_beat_thread_(nullptr) {} | |||
| virtual ~Node() = default; | |||
| using OnNodeEventMessage = std::function<void(const NodeEvent &event)>; | |||
| void set_callback(const OnNodeEventMessage &on_node_event_message); | |||
| std::string node_id() const; | |||
| uint32_t rank_id() const; | |||
| void Wait(uint64_t request_id); | |||
| protected: | |||
| void Heartbeat(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessHeartbeatResp(const CommMessage &message); | |||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessFetchServersResp(const CommMessage &message); | |||
| void Disconnect(const std::shared_ptr<TcpClient> &client); | |||
| void WaitForStart(); | |||
| void WaitForDisconnect(); | |||
| void SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||
| void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||
| void NotifyMessageArrival(const CommMessage &message); | |||
| NodeInfo node_info_; | |||
| std::atomic<bool> is_ready_; | |||
| std::atomic<bool> is_finish_; | |||
| std::atomic<bool> is_timeout_; | |||
| std::atomic<bool> is_already_stopped_; | |||
| std::atomic_uint64_t next_request_id_; | |||
| std::unique_ptr<std::thread> heart_beat_thread_; | |||
| OnNodeEventMessage on_node_event_message_; | |||
| // rank_id-><ip, port> | |||
| std::unordered_map<int, std::pair<std::string, uint16_t>> server_rank_ids_; | |||
| // timestamp-><expected responses, actual responses> | |||
| std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; | |||
| std::mutex message_mutex_; | |||
| std::condition_variable message_tracker_cond_; | |||
| std::mutex wait_finish_mutex_; | |||
| std::condition_variable wait_finish_cond_; | |||
| std::mutex wait_start_mutex_; | |||
| std::condition_variable wait_start_cond_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_NODE_H_ | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_NODE_INFO_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_NODE_INFO_H_ | |||
| #include <string> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| enum NodeEvent { NODE_TIMEOUT = 0 }; | |||
| struct NodeInfo { | |||
| NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} | |||
| // ip | |||
| std::string ip_; | |||
| // the port of this node | |||
| uint16_t port_; | |||
| // the current Node unique id:0,1,2... | |||
| std::string node_id_; | |||
| // the role of the node: worker,server,scheduler | |||
| NodeRole node_role_; | |||
| // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] | |||
| uint32_t rank_id_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_NODE_INFO_H_ | |||
| @@ -0,0 +1,137 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/node_manager.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void NodeManager::InitNodeNum() { total_node_num_ = ClusterConfig::server_num() + ClusterConfig::worker_num(); } | |||
| int NodeManager::NextRankId(const RegisterMessage ®ister_message) { | |||
| std::lock_guard<std::mutex> lock(assign_rank_id_mutex_); | |||
| int rank_id = -1; | |||
| const std::string &node_id = register_message.node_id(); | |||
| if (nodes_info_.find(node_id) != nodes_info_.end()) { | |||
| rank_id = nodes_info_[node_id].rank_id_; | |||
| MS_LOG(INFO) << "The node id: " << node_id << " is already assigned!"; | |||
| return rank_id; | |||
| } | |||
| if (register_message.role() == NodeRole::SERVER) { | |||
| const std::string &ip = register_message.ip(); | |||
| uint32_t port = register_message.port(); | |||
| rank_id = ++next_server_rank_id_; | |||
| NodeInfo node_info; | |||
| node_info.node_role_ = NodeRole::SERVER; | |||
| node_info.node_id_ = node_id; | |||
| node_info.rank_id_ = rank_id; | |||
| node_info.ip_ = ip; | |||
| node_info.port_ = port; | |||
| nodes_info_[node_id] = node_info; | |||
| MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port | |||
| << " assign rank id:" << rank_id; | |||
| } else if (register_message.role() == NodeRole::WORKER) { | |||
| rank_id = ++next_worker_rank_id_; | |||
| NodeInfo node_info; | |||
| node_info.node_role_ = NodeRole::WORKER; | |||
| node_info.node_id_ = node_id; | |||
| node_info.rank_id_ = rank_id; | |||
| nodes_info_[node_id] = node_info; | |||
| MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id; | |||
| } | |||
| return rank_id; | |||
| } | |||
| void NodeManager::UpdateHeartbeat(const std::string &node_id) { | |||
| std::lock_guard<std::mutex> lock(heartbeat_mutex_); | |||
| NodeInfo node_info = nodes_info_[node_id]; | |||
| struct timeval current_time {}; | |||
| (void)gettimeofday(¤t_time, nullptr); | |||
| heartbeats_[node_id] = current_time; | |||
| MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id | |||
| << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; | |||
| } | |||
| std::vector<ServersMeta> NodeManager::FetchServersMeta() { | |||
| std::vector<ServersMeta> servers_meta_list; | |||
| for (auto it = nodes_info_.begin(); it != nodes_info_.end(); ++it) { | |||
| if (it->second.node_role_ == NodeRole::SERVER) { | |||
| ServersMeta servers_meta; | |||
| servers_meta.set_rank_id(it->second.rank_id_); | |||
| servers_meta.set_ip(it->second.ip_); | |||
| servers_meta.set_port(it->second.port_); | |||
| servers_meta_list.push_back(servers_meta); | |||
| } | |||
| } | |||
| return servers_meta_list; | |||
| } | |||
| void NodeManager::UpdateClusterState() { | |||
| // 1. update cluster timeout state | |||
| struct timeval current_time {}; | |||
| (void)gettimeofday(¤t_time, nullptr); | |||
| timeout_nodes_info_.clear(); | |||
| for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) { | |||
| if (it->second.tv_sec + ClusterConfig::heartbeat_timeout() < current_time.tv_sec) { | |||
| MS_LOG(ERROR) << "The node id:" << it->first << " is timeout!"; | |||
| timeout_nodes_info_[it->first] = nodes_info_[it->first]; | |||
| } | |||
| } | |||
| if (!timeout_nodes_info_.empty()) { | |||
| is_cluster_timeout_ = true; | |||
| for (auto it = timeout_nodes_info_.begin(); it != timeout_nodes_info_.end(); ++it) { | |||
| finish_nodes_id_.insert(it->first); | |||
| } | |||
| } | |||
| // 2. update cluster finish state | |||
| if (finish_nodes_id_.size() == total_node_num_) { | |||
| is_cluster_finish_ = true; | |||
| is_cluster_ready_ = true; | |||
| } | |||
| // 3. update cluster ready state | |||
| if (nodes_info_.size() == total_node_num_) { | |||
| is_cluster_ready_ = true; | |||
| } | |||
| } | |||
| void NodeManager::CheckClusterTimeout() { | |||
| if (total_node_num_ != nodes_info_.size()) { | |||
| MS_LOG(WARNING) << "The cluster is not ready after " << ClusterConfig::cluster_available_timeout() | |||
| << " seconds,so finish the cluster"; | |||
| is_cluster_timeout_ = true; | |||
| } | |||
| } | |||
| void NodeManager::AddFinishNode(const FinishMessage &finish_message) { | |||
| finish_nodes_id_.insert(finish_message.node_id()); | |||
| } | |||
| std::unordered_map<std::string, NodeInfo> NodeManager::nodes_info() { return nodes_info_; } | |||
| bool NodeManager::is_cluster_finish() { return is_cluster_finish_.load(); } | |||
| bool NodeManager::is_cluster_ready() { return is_cluster_ready_.load(); } | |||
| bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_; } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef RPC_CLUSTER_MANAGER_H | |||
| #define RPC_CLUSTER_MANAGER_H | |||
| #include <atomic> | |||
| #include <cstdlib> | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <condition_variable> | |||
| #include <unordered_set> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/node.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class NodeManager { | |||
| public: | |||
| NodeManager() | |||
| : is_cluster_ready_(false), | |||
| is_cluster_finish_(false), | |||
| is_cluster_timeout_(false), | |||
| total_node_num_(0), | |||
| next_worker_rank_id_(-1), | |||
| next_server_rank_id_(-1) {} | |||
| virtual ~NodeManager() = default; | |||
| enum ClusterState { STARTING, STARTED, FAILED, STOPPING, STOPPED }; | |||
| void InitNodeNum(); | |||
| int NextRankId(const RegisterMessage ®ister_message); | |||
| void UpdateHeartbeat(const std::string &node_id); | |||
| std::vector<ServersMeta> FetchServersMeta(); | |||
| void UpdateClusterState(); | |||
| void CheckClusterTimeout(); | |||
| void AddFinishNode(const FinishMessage &finish_message); | |||
| std::unordered_map<std::string, NodeInfo> nodes_info(); | |||
| bool is_cluster_ready(); | |||
| bool is_cluster_finish(); | |||
| bool is_cluster_timeout(); | |||
| private: | |||
| std::atomic<bool> is_cluster_ready_; | |||
| std::atomic<bool> is_cluster_finish_; | |||
| std::atomic<bool> is_cluster_timeout_; | |||
| uint32_t total_node_num_; | |||
| std::atomic<int> next_worker_rank_id_; | |||
| std::atomic<int> next_server_rank_id_; | |||
| // worker nodes and server nodes | |||
| std::unordered_map<std::string, NodeInfo> nodes_info_; | |||
| std::mutex assign_rank_id_mutex_; | |||
| std::mutex heartbeat_mutex_; | |||
| std::unordered_map<std::string, timeval> heartbeats_; | |||
| // timeout nodes | |||
| std::unordered_map<std::string, NodeInfo> timeout_nodes_info_; | |||
| std::unordered_set<std::string> finish_nodes_id_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // RPC_CLUSTER_MANAGER_H | |||
| @@ -25,6 +25,7 @@ enum NodeCommand { | |||
| HEARTBEAT = 2; | |||
| SEND_DATA = 3; | |||
| FETCH_SERVER = 4; | |||
| FINISH = 5; | |||
| } | |||
| enum NodeRole { | |||
| @@ -65,6 +66,7 @@ message HeartbeatRespMessage { | |||
| // Is the entire system ready to use. | |||
| bool is_cluster_ready = 1; | |||
| bool is_cluster_finish = 2; | |||
| bool is_cluster_timeout = 3; | |||
| } | |||
| message FetchServersRespMessage { | |||
| @@ -78,6 +80,11 @@ message ServersMeta { | |||
| } | |||
| message FinishMessage { | |||
| // the current Node unique id:0,1,2... | |||
| string node_id = 1; | |||
| } | |||
| message CommMessage { | |||
| MessageMeta pb_meta = 1; | |||
| bytes data = 2; | |||
| @@ -32,6 +32,7 @@ | |||
| #include <atomic> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| namespace mindspore { | |||
| @@ -85,6 +85,8 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon | |||
| this->client_accept_ = client_accept; | |||
| } | |||
| void TcpServer::set_timer_once_callback(const OnTimerOnce &timer) { on_timer_once_callback_ = timer; } | |||
| void TcpServer::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } | |||
| void TcpServer::Init() { | |||
| @@ -165,7 +167,21 @@ void TcpServer::StartTimerOnlyOnce(const uint32_t &time) { | |||
| struct timeval timeout {}; | |||
| timeout.tv_sec = time; | |||
| timeout.tv_usec = 0; | |||
| ev = evtimer_new(base_, TimerCallback, this); | |||
| ev = evtimer_new(base_, TimerOnceCallback, this); | |||
| MS_EXCEPTION_IF_NULL(ev); | |||
| evtimer_add(ev, &timeout); | |||
| } | |||
| void TcpServer::StartTimer(const uint32_t &time) { | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| struct event *ev = nullptr; | |||
| if (time == 0) { | |||
| MS_LOG(EXCEPTION) << "The time should not be 0!"; | |||
| } | |||
| struct timeval timeout {}; | |||
| timeout.tv_sec = time; | |||
| timeout.tv_usec = 0; | |||
| ev = event_new(base_, -1, EV_PERSIST, TimerCallback, this); | |||
| MS_EXCEPTION_IF_NULL(ev); | |||
| evtimer_add(ev, &timeout); | |||
| } | |||
| @@ -321,7 +337,15 @@ void TcpServer::TimerCallback(evutil_socket_t, int16_t, void *arg) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto tcp_server = reinterpret_cast<TcpServer *>(arg); | |||
| if (tcp_server->on_timer_callback_) { | |||
| tcp_server->on_timer_callback_(*tcp_server); | |||
| tcp_server->on_timer_callback_(); | |||
| } | |||
| } | |||
| void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto tcp_server = reinterpret_cast<TcpServer *>(arg); | |||
| if (tcp_server->on_timer_once_callback_) { | |||
| tcp_server->on_timer_once_callback_(*tcp_server); | |||
| } | |||
| } | |||
| @@ -337,6 +361,8 @@ void TcpServer::SendMessage(const CommMessage &message) { | |||
| uint16_t TcpServer::BoundPort() const { return server_port_; } | |||
| std::string TcpServer::BoundIp() const { return server_address_; } | |||
| int TcpServer::ConnectionNum() const { return connections_.size(); } | |||
| const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } | |||
| @@ -35,6 +35,7 @@ | |||
| #include <atomic> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/tcp_message_handler.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -71,18 +72,21 @@ class TcpServer { | |||
| using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>; | |||
| using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>; | |||
| using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>; | |||
| using OnTimer = std::function<void(const TcpServer &)>; | |||
| using OnTimerOnce = std::function<void(const TcpServer &)>; | |||
| using OnTimer = std::function<void()>; | |||
| explicit TcpServer(const std::string &address, std::uint16_t port); | |||
| virtual ~TcpServer(); | |||
| void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | |||
| const OnAccepted &client_accept); | |||
| void set_timer_once_callback(const OnTimerOnce &timer); | |||
| void set_timer_callback(const OnTimer &timer); | |||
| void Init(); | |||
| void Start(); | |||
| void StartWithNoBlock(); | |||
| void StartTimerOnlyOnce(const uint32_t &time); | |||
| void StartTimer(const uint32_t &time); | |||
| void Stop(); | |||
| void SendToAllClients(const char *data, size_t len); | |||
| void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); | |||
| @@ -92,6 +96,7 @@ class TcpServer { | |||
| void SendMessage(const TcpConnection &conn, const CommMessage &message); | |||
| void SendMessage(const CommMessage &message); | |||
| uint16_t BoundPort() const; | |||
| std::string BoundIp() const; | |||
| int ConnectionNum() const; | |||
| const std::map<evutil_socket_t, const TcpConnection *> &Connections() const; | |||
| @@ -102,6 +107,7 @@ class TcpServer { | |||
| static void ReadCallback(struct bufferevent *, void *connection); | |||
| static void EventCallback(struct bufferevent *, std::int16_t events, void *server); | |||
| static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); | |||
| static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg); | |||
| virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); | |||
| struct event_base *base_; | |||
| @@ -117,6 +123,7 @@ class TcpServer { | |||
| OnAccepted client_accept_; | |||
| std::recursive_mutex connection_mutex_; | |||
| OnServerReceiveMessage message_callback_; | |||
| OnTimerOnce on_timer_once_callback_; | |||
| OnTimer on_timer_callback_; | |||
| }; | |||
| } // namespace core | |||