| @@ -19,6 +19,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc") | |||
| endif () | |||
| if (NOT ENABLE_D) | |||
| @@ -74,30 +74,161 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me | |||
| on_node_event_message_ = on_node_event_message; | |||
| } | |||
| void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) { | |||
| bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| const uint32_t &timeout) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| return SendMessageSync(client, comm_message); | |||
| } | |||
| bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||
| if (rank_ids.size() != data.size()) { | |||
| MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!"; | |||
| } | |||
| for (size_t it = 0; it < rank_ids.size(); ++it) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(data.at(it)); | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| CommMessage *comm_message_resp, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(comm_message_resp); | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| *comm_message_resp = res[rank_id]; | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| message_meta.set_rank_id(node_info_.rank_id_); | |||
| message_meta.set_role(node_info_.node_role_); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| client->SendMessage(comm_message); | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp, | |||
| const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(comm_message_resp); | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||
| if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) { | |||
| MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; | |||
| } | |||
| size_t len = rank_ids.size(); | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| for (size_t it = 0; it < len; ++it) { | |||
| comm_message_resp->at(it) = &res[rank_ids.at(it)]; | |||
| } | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| for (size_t it = 0; it < len; ++it) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(data.at(it)); | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(message_tracker_mutex_); | |||
| bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second; | |||
| return ret; | |||
| }); | |||
| message_tracker_.erase(request_id); | |||
| return res; | |||
| } | |||
| void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) { | |||
| MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | |||
| << " begin send heartbeat to the scheduler!"; | |||
| heart_beat_thread_ = std::make_unique<std::thread>([&]() { | |||
| while (!is_finish_.load()) { | |||
| Heartbeat(client); | |||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::HEARTBEAT); | |||
| HeartbeatMessage heartbeat_message; | |||
| heartbeat_message.set_node_id(node_info_.node_id_); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(heartbeat_message.SerializeAsString()); | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | |||
| } | |||
| } | |||
| }); | |||
| heart_beat_thread_->detach(); | |||
| } | |||
| void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::HEARTBEAT); | |||
| HeartbeatMessage heartbeat_message; | |||
| heartbeat_message.set_node_id(node_info_.node_id_); | |||
| heartbeat_message.set_is_node_finish(is_node_finish); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(heartbeat_message.SerializeAsString()); | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | |||
| } | |||
| } | |||
| void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { | |||
| HeartbeatRespMessage heartbeat_resp_message; | |||
| heartbeat_resp_message.ParseFromString(message.data()); | |||
| @@ -106,8 +237,9 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { | |||
| wait_start_cond_.notify_all(); | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!"; | |||
| } | |||
| is_finish_ = heartbeat_resp_message.is_cluster_finish(); | |||
| if (is_finish_.load()) { | |||
| if (heartbeat_resp_message.is_cluster_finish()) { | |||
| Heartbeat(client_to_scheduler_, true); | |||
| is_finish_ = true; | |||
| wait_finish_cond_.notify_all(); | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!"; | |||
| } | |||
| @@ -115,6 +247,10 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { | |||
| if (is_timeout_ && on_node_event_message_) { | |||
| is_ready_ = true; | |||
| wait_start_cond_.notify_all(); | |||
| on_node_event_message_(NodeEvent::CLUSTER_TIMEOUT); | |||
| } | |||
| if (heartbeat_resp_message.is_node_timeout() && on_node_event_message_) { | |||
| on_node_event_message_(NodeEvent::NODE_TIMEOUT); | |||
| } | |||
| } | |||
| @@ -207,6 +343,101 @@ bool AbstractNode::InitClientToScheduler() { | |||
| }); | |||
| return client_to_scheduler_->WaitConnected(); | |||
| } | |||
| const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &rank_id) { | |||
| std::lock_guard<std::mutex> lock(client_mutex_); | |||
| if (connected_nodes_.find(rank_id) != connected_nodes_.end()) { | |||
| return connected_nodes_[rank_id]; | |||
| } else { | |||
| if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) { | |||
| MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!"; | |||
| } | |||
| std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; | |||
| uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; | |||
| auto client = std::make_shared<TcpClient>(ip, port); | |||
| client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| case NodeCommand::SEND_DATA: | |||
| ProcessSendDataResp(message); | |||
| RunMessageCallback(message.pb_meta().request_id()); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| NotifyMessageArrival(message); | |||
| }); | |||
| client->Init(); | |||
| connected_nodes_[rank_id] = client; | |||
| return connected_nodes_[rank_id]; | |||
| } | |||
| } | |||
| bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| return Wait(request_id, timeout); | |||
| } | |||
| void AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| } | |||
| void AbstractNode::ProcessSendDataResp(const CommMessage &message) { | |||
| std::lock_guard<std::mutex> lock(receive_messages_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| const uint32_t &rank_id = message_meta.rank_id(); | |||
| const uint64_t request_id = message_meta.request_id(); | |||
| auto it = receive_messages_.find(request_id); | |||
| if (it != receive_messages_.end()) { | |||
| it->second.insert(std::make_pair(rank_id, message)); | |||
| } else { | |||
| std::unordered_map<uint32_t, CommMessage> res; | |||
| res.insert(std::make_pair(rank_id, message)); | |||
| receive_messages_[request_id] = res; | |||
| } | |||
| } | |||
| void AbstractNode::RunMessageCallback(const uint64_t &request_id) { | |||
| message_callbacks_mutex_.lock(); | |||
| // When receiving a message's response, Then compare with the desired number of responses, | |||
| // If they are equal, then call the callback function | |||
| if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) { | |||
| auto it = message_callbacks_.find(request_id); | |||
| if (it != message_callbacks_.end()) { | |||
| message_callbacks_mutex_.unlock(); | |||
| if (it->second) { | |||
| it->second(); | |||
| } | |||
| message_callbacks_mutex_.lock(); | |||
| message_callbacks_.erase(it); | |||
| } | |||
| } | |||
| message_callbacks_mutex_.unlock(); | |||
| } | |||
| void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) { | |||
| if (!message_callback) { | |||
| return; | |||
| } | |||
| std::lock_guard<std::mutex> lock(message_callbacks_mutex_); | |||
| message_callbacks_[request_id] = message_callback; | |||
| } | |||
| void AbstractNode::NotifyMessageArrival(const CommMessage &message) { | |||
| std::lock_guard<std::mutex> lock(message_tracker_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| uint64_t request_id = message_meta.request_id(); | |||
| message_tracker_[request_id].second++; | |||
| message_tracker_cond_.notify_all(); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -20,6 +20,9 @@ | |||
| #include <utility> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "ps/core/node.h" | |||
| @@ -34,21 +37,60 @@ class AbstractNode : public Node { | |||
| bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| void set_event_callback(const OnNodeEventMessage &on_node_event_message); | |||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| protected: | |||
| void Register(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessRegisterResp(const CommMessage &message); | |||
| void Heartbeat(const std::shared_ptr<TcpClient> &client); | |||
| void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client); | |||
| void Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false); | |||
| void ProcessHeartbeatResp(const CommMessage &message); | |||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessFetchServersResp(const CommMessage &message); | |||
| bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); | |||
| bool WaitForDisconnect(const uint32_t &timeout); | |||
| bool InitClientToScheduler(); | |||
| const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); | |||
| bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||
| void ProcessSendDataResp(const CommMessage &message); | |||
| void RunMessageCallback(const uint64_t &request_id); | |||
| void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); | |||
| void NotifyMessageArrival(const CommMessage &message); | |||
| std::unique_ptr<std::thread> heart_beat_thread_; | |||
| std::unique_ptr<std::thread> client_to_scheduler_thread_; | |||
| std::shared_ptr<TcpClient> client_to_scheduler_; | |||
| OnNodeEventMessage on_node_event_message_; | |||
| // the map's key is: <node_role,rank_id>, the map's value is: <ip, port> | |||
| std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; | |||
| std::mutex client_mutex_; | |||
| // the map's key is: rank_id | |||
| std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_; | |||
| // the map's key is: request_id, the map's value is: <expected responses, actual responses> | |||
| std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; | |||
| std::mutex message_tracker_mutex_; | |||
| std::condition_variable message_tracker_cond_; | |||
| // the map's key is: request_id, the map's value is:<rank_id, CommMessage> | |||
| std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_; | |||
| std::mutex receive_messages_mutex_; | |||
| // the map's key is: request_id | |||
| std::unordered_map<uint64_t, MessageCallback> message_callbacks_; | |||
| std::mutex message_callbacks_mutex_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -25,131 +25,6 @@ uint32_t Node::rank_id() const { return node_info_.rank_id_; } | |||
| NodeRole Node::role() const { return node_info_.node_role_; } | |||
| bool Node::Wait(uint64_t request_id, const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(message_tracker_mutex_); | |||
| bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second; | |||
| return ret; | |||
| }); | |||
| message_tracker_.erase(request_id); | |||
| return res; | |||
| } | |||
| bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| const uint32_t &timeout) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| return SendMessageSync(client, comm_message); | |||
| } | |||
| bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||
| const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||
| if (rank_ids.size() != data.size()) { | |||
| MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!"; | |||
| } | |||
| for (size_t it = 0; it < rank_ids.size(); ++it) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(data.at(it)); | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| CommMessage *comm_message_resp, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(comm_message_resp); | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| *comm_message_resp = res[rank_id]; | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| message_meta.set_rank_id(node_info_.rank_id_); | |||
| message_meta.set_role(node_info_.node_role_); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| client->SendMessage(comm_message); | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||
| std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(comm_message_resp); | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||
| if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) { | |||
| MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; | |||
| } | |||
| size_t len = rank_ids.size(); | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| for (size_t it = 0; it < len; ++it) { | |||
| comm_message_resp->at(it) = &res[rank_ids.at(it)]; | |||
| } | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| for (size_t it = 0; it < len; ++it) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(data.at(it)); | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool Node::WaitForStart(const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(wait_start_mutex_); | |||
| bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| @@ -161,101 +36,6 @@ bool Node::WaitForStart(const uint32_t &timeout) { | |||
| }); | |||
| return res; | |||
| } | |||
| bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| return Wait(request_id, timeout); | |||
| } | |||
| void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| } | |||
| const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) { | |||
| std::lock_guard<std::mutex> lock(client_mutex_); | |||
| if (connected_nodes_.find(rank_id) != connected_nodes_.end()) { | |||
| return connected_nodes_[rank_id]; | |||
| } else { | |||
| if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) { | |||
| MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!"; | |||
| } | |||
| std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; | |||
| uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; | |||
| auto client = std::make_shared<TcpClient>(ip, port); | |||
| client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| case NodeCommand::SEND_DATA: | |||
| ProcessSendDataResp(message); | |||
| RunMessageCallback(message.pb_meta().request_id()); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| NotifyMessageArrival(message); | |||
| }); | |||
| client->Init(); | |||
| connected_nodes_[rank_id] = client; | |||
| return connected_nodes_[rank_id]; | |||
| } | |||
| } | |||
| void Node::ProcessSendDataResp(const CommMessage &message) { | |||
| std::lock_guard<std::mutex> lock(receive_messages_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| const uint32_t &rank_id = message_meta.rank_id(); | |||
| const uint64_t request_id = message_meta.request_id(); | |||
| auto it = receive_messages_.find(request_id); | |||
| if (it != receive_messages_.end()) { | |||
| it->second.insert(std::make_pair(rank_id, message)); | |||
| } else { | |||
| std::unordered_map<uint32_t, CommMessage> res; | |||
| res.insert(std::make_pair(rank_id, message)); | |||
| receive_messages_[request_id] = res; | |||
| } | |||
| } | |||
| void Node::RunMessageCallback(const uint64_t &request_id) { | |||
| message_callbacks_mutex_.lock(); | |||
| // When receiving a message's response, Then compare with the desired number of responses, | |||
| // If they are equal, then call the callback function | |||
| if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) { | |||
| auto it = message_callbacks_.find(request_id); | |||
| if (it != message_callbacks_.end()) { | |||
| message_callbacks_mutex_.unlock(); | |||
| if (it->second) { | |||
| it->second(); | |||
| } | |||
| message_callbacks_mutex_.lock(); | |||
| message_callbacks_.erase(it); | |||
| } | |||
| } | |||
| message_callbacks_mutex_.unlock(); | |||
| } | |||
| void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) { | |||
| if (!message_callback) { | |||
| return; | |||
| } | |||
| std::lock_guard<std::mutex> lock(message_callbacks_mutex_); | |||
| message_callbacks_[request_id] = message_callback; | |||
| } | |||
| void Node::NotifyMessageArrival(const CommMessage &message) { | |||
| std::lock_guard<std::mutex> lock(message_tracker_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| uint64_t request_id = message_meta.request_id(); | |||
| message_tracker_[request_id].second++; | |||
| message_tracker_cond_.notify_all(); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -29,7 +29,6 @@ | |||
| #include <condition_variable> | |||
| #include <utility> | |||
| #include <tuple> | |||
| #include <map> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| @@ -66,28 +65,8 @@ class Node { | |||
| uint32_t rank_id() const; | |||
| NodeRole role() const; | |||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| protected: | |||
| bool WaitForStart(const uint32_t &timeout); | |||
| bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||
| const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); | |||
| void ProcessSendDataResp(const CommMessage &message); | |||
| void RunMessageCallback(const uint64_t &request_id); | |||
| void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); | |||
| void NotifyMessageArrival(const CommMessage &message); | |||
| NodeInfo node_info_; | |||
| std::atomic<bool> is_ready_; | |||
| @@ -97,28 +76,11 @@ class Node { | |||
| std::atomic<bool> is_already_finished_; | |||
| std::atomic_uint64_t next_request_id_; | |||
| // <NodeRole,rank_id>-><ip, port> | |||
| std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; | |||
| // rank_id->tcpclient | |||
| std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_; | |||
| // request_id-><expected responses, actual responses> | |||
| std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; | |||
| std::mutex message_tracker_mutex_; | |||
| std::condition_variable message_tracker_cond_; | |||
| std::mutex wait_finish_mutex_; | |||
| std::condition_variable wait_finish_cond_; | |||
| std::mutex wait_start_mutex_; | |||
| std::condition_variable wait_start_cond_; | |||
| std::mutex wait_finish_mutex_; | |||
| std::condition_variable wait_finish_cond_; | |||
| std::mutex finish_mutex_; | |||
| std::mutex client_mutex_; | |||
| // request_id -> <rank_id, CommMessage> | |||
| std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_; | |||
| std::mutex receive_messages_mutex_; | |||
| // request_id -> MessageCallback | |||
| std::unordered_map<uint64_t, MessageCallback> message_callbacks_; | |||
| std::mutex message_callbacks_mutex_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -26,7 +26,7 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| enum NodeEvent { NODE_TIMEOUT = 0 }; | |||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 }; | |||
| struct NodeInfo { | |||
| NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} | |||
| @@ -69,6 +69,10 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) { | |||
| << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; | |||
| } | |||
| void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); } | |||
| bool NodeManager::CheckNodesFinishState() { return heartbeats_finish_nodes_.size() == nodes_info_.size(); } | |||
| std::vector<ServersMeta> NodeManager::FetchServersMeta() { | |||
| std::vector<ServersMeta> servers_meta_list; | |||
| for (auto it = nodes_info_.begin(); it != nodes_info_.end(); ++it) { | |||
| @@ -131,7 +135,11 @@ bool NodeManager::is_cluster_finish() { return is_cluster_finish_.load(); } | |||
| bool NodeManager::is_cluster_ready() { return is_cluster_ready_.load(); } | |||
| bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_; } | |||
| bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_.load(); } | |||
| bool NodeManager::is_node_timeout() { return is_node_timeout_.load(); } | |||
| void NodeManager::set_cluster_timeout(bool is_cluster_timeout) { is_cluster_timeout_ = is_cluster_timeout; } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef RPC_CLUSTER_MANAGER_H | |||
| #define RPC_CLUSTER_MANAGER_H | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_ | |||
| #include <atomic> | |||
| #include <cstdlib> | |||
| @@ -45,6 +45,7 @@ class NodeManager { | |||
| : is_cluster_ready_(false), | |||
| is_cluster_finish_(false), | |||
| is_cluster_timeout_(false), | |||
| is_node_timeout_(false), | |||
| total_node_num_(0), | |||
| next_worker_rank_id_(-1), | |||
| next_server_rank_id_(-1) {} | |||
| @@ -55,6 +56,8 @@ class NodeManager { | |||
| void InitNodeNum(); | |||
| int NextRankId(const RegisterMessage ®ister_message); | |||
| void UpdateHeartbeat(const std::string &node_id); | |||
| void UpdateNodeFinishState(const std::string &node_id); | |||
| bool CheckNodesFinishState(); | |||
| std::vector<ServersMeta> FetchServersMeta(); | |||
| void UpdateClusterState(); | |||
| void CheckClusterTimeout(); | |||
| @@ -63,11 +66,14 @@ class NodeManager { | |||
| bool is_cluster_ready(); | |||
| bool is_cluster_finish(); | |||
| bool is_cluster_timeout(); | |||
| bool is_node_timeout(); | |||
| void set_cluster_timeout(bool is_cluster_timeout); | |||
| private: | |||
| std::atomic<bool> is_cluster_ready_; | |||
| std::atomic<bool> is_cluster_finish_; | |||
| std::atomic<bool> is_cluster_timeout_; | |||
| std::atomic<bool> is_node_timeout_; | |||
| uint32_t total_node_num_; | |||
| std::atomic<int> next_worker_rank_id_; | |||
| std::atomic<int> next_server_rank_id_; | |||
| @@ -76,6 +82,7 @@ class NodeManager { | |||
| std::mutex assign_rank_id_mutex_; | |||
| std::mutex heartbeat_mutex_; | |||
| std::unordered_map<std::string, timeval> heartbeats_; | |||
| std::unordered_set<std::string> heartbeats_finish_nodes_; | |||
| // timeout nodes | |||
| std::unordered_map<std::string, NodeInfo> timeout_nodes_info_; | |||
| std::unordered_set<std::string> finish_nodes_id_; | |||
| @@ -83,4 +90,4 @@ class NodeManager { | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // RPC_CLUSTER_MANAGER_H | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_ | |||
| @@ -64,6 +64,7 @@ message RegisterRespMessage { | |||
| message HeartbeatMessage { | |||
| // the current Node unique id:0,1,2... | |||
| string node_id = 1; | |||
| bool is_node_finish = 2; | |||
| } | |||
| message HeartbeatRespMessage { | |||
| @@ -71,6 +72,7 @@ message HeartbeatRespMessage { | |||
| bool is_cluster_ready = 1; | |||
| bool is_cluster_finish = 2; | |||
| bool is_cluster_timeout = 3; | |||
| bool is_node_timeout = 4; | |||
| } | |||
| message FetchServersRespMessage { | |||
| @@ -0,0 +1,222 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/scheduler_node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| SchedulerNode::~SchedulerNode() { | |||
| MS_LOG(INFO) << "Stop scheduler node!"; | |||
| if (!is_already_stopped_) { | |||
| is_already_stopped_ = true; | |||
| server_->Stop(); | |||
| if (scheduler_thread_->joinable()) { | |||
| scheduler_thread_->join(); | |||
| } | |||
| if (update_state_thread_->joinable()) { | |||
| update_state_thread_->join(); | |||
| } | |||
| is_ready_ = true; | |||
| } | |||
| } | |||
| bool SchedulerNode::Start(const uint32_t &timeout) { | |||
| MS_LOG(INFO) << "Start scheduler node!"; | |||
| Initialize(); | |||
| StartUpdateClusterStateTimer(); | |||
| if (!WaitForStart(timeout)) { | |||
| MS_LOG(ERROR) << "Start Scheduler node timeout!"; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "Start the scheduler node is successful!"; | |||
| return true; | |||
| } | |||
| void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| HeartbeatMessage heartbeat_message; | |||
| heartbeat_message.ParseFromString(message.data()); | |||
| node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); | |||
| if (heartbeat_message.is_node_finish()) { | |||
| node_manager_.UpdateNodeFinishState(heartbeat_message.node_id()); | |||
| } | |||
| if (heartbeat_message.is_node_finish() && node_manager_.CheckNodesFinishState()) { | |||
| MS_LOG(INFO) << "The scheduler node receive all the finish cmd!"; | |||
| is_finish_ = true; | |||
| wait_finish_cond_.notify_all(); | |||
| } | |||
| HeartbeatRespMessage heartbeat_resp_message; | |||
| heartbeat_resp_message.set_is_cluster_ready(node_manager_.is_cluster_ready()); | |||
| heartbeat_resp_message.set_is_cluster_finish(node_manager_.is_cluster_finish()); | |||
| heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); | |||
| heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||
| comm_message.set_data(heartbeat_resp_message.SerializeAsString()); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||
| } | |||
| void SchedulerNode::Initialize() { | |||
| CreateTcpServer(); | |||
| is_already_stopped_ = false; | |||
| node_info_.node_id_ = CommUtil::GenerateUUID(); | |||
| node_info_.node_role_ = NodeRole::SCHEDULER; | |||
| MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_; | |||
| } | |||
| void SchedulerNode::CreateTcpServer() { | |||
| node_manager_.InitNodeNum(); | |||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | |||
| uint32_t scheduler_port = ClusterConfig::scheduler_port(); | |||
| server_ = std::make_unique<TcpServer>(scheduler_host, scheduler_port); | |||
| server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| case NodeCommand::HEARTBEAT: | |||
| ProcessHeartbeat(server, conn, message); | |||
| break; | |||
| case NodeCommand::REGISTER: | |||
| ProcessRegister(server, conn, message); | |||
| break; | |||
| case NodeCommand::FINISH: | |||
| ProcessFinish(server, conn, message); | |||
| break; | |||
| case NodeCommand::FETCH_SERVER: | |||
| ProcessFetchServers(server, conn, message); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| }); | |||
| server_->Init(); | |||
| scheduler_thread_ = std::make_unique<std::thread>([&]() { | |||
| MS_LOG(INFO) << "The scheduler node start a tcp server!"; | |||
| server_->Start(); | |||
| }); | |||
| scheduler_thread_->detach(); | |||
| } | |||
| void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| MS_LOG(INFO) << "The scheduler process a register message!"; | |||
| RegisterMessage register_message; | |||
| register_message.ParseFromString(message.data()); | |||
| // assign worker node and server node rank id | |||
| int rank_id = node_manager_.NextRankId(register_message); | |||
| if (rank_id < 0) { | |||
| MS_LOG(EXCEPTION) << "The rank id is wrong!"; | |||
| } | |||
| const std::string &node_id = register_message.node_id(); | |||
| node_manager_.UpdateHeartbeat(node_id); | |||
| RegisterRespMessage register_resp_message; | |||
| register_resp_message.set_node_id(node_id); | |||
| register_resp_message.set_rank_id(rank_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||
| comm_message.set_data(register_resp_message.SerializeAsString()); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||
| } | |||
| void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| FinishMessage finish_message; | |||
| finish_message.ParseFromString(message.data()); | |||
| node_manager_.AddFinishNode(finish_message); | |||
| MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, message); | |||
| } | |||
| void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, | |||
| const CommMessage &message) { | |||
| FetchServersRespMessage fetch_servers_message; | |||
| std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta(); | |||
| *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||
| comm_message.set_data(fetch_servers_message.SerializeAsString()); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||
| } | |||
| void SchedulerNode::StartUpdateClusterStateTimer() { | |||
| MS_LOG(WARNING) << "The scheduler start a heartbeat timer!"; | |||
| update_state_thread_ = std::make_unique<std::thread>([&]() { | |||
| auto start_time = std::chrono::steady_clock::now(); | |||
| while (!is_finish_.load()) { | |||
| // 1. update cluster timeout | |||
| if (!node_manager_.is_cluster_ready() && (std::chrono::steady_clock::now() - start_time > | |||
| std::chrono::seconds(ClusterConfig::cluster_available_timeout()))) { | |||
| node_manager_.CheckClusterTimeout(); | |||
| } | |||
| // 2. update cluster state | |||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | |||
| node_manager_.UpdateClusterState(); | |||
| if (node_manager_.is_cluster_ready()) { | |||
| is_ready_ = true; | |||
| wait_start_cond_.notify_all(); | |||
| } | |||
| if (node_manager_.is_cluster_finish()) { | |||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval() * 2)); | |||
| is_finish_ = true; | |||
| wait_finish_cond_.notify_all(); | |||
| } | |||
| } | |||
| }); | |||
| update_state_thread_->detach(); | |||
| } | |||
| bool SchedulerNode::Stop() { | |||
| MS_LOG(INFO) << "Stop scheduler node!"; | |||
| if (!is_already_stopped_) { | |||
| is_already_stopped_ = true; | |||
| server_->Stop(); | |||
| if (scheduler_thread_->joinable()) { | |||
| scheduler_thread_->join(); | |||
| } | |||
| if (update_state_thread_->joinable()) { | |||
| update_state_thread_->join(); | |||
| } | |||
| is_ready_ = true; | |||
| } | |||
| return true; | |||
| } | |||
| bool SchedulerNode::Finish(const uint32_t &timeout) { | |||
| MS_LOG(INFO) << "Finish scheduler node!"; | |||
| std::unique_lock<std::mutex> lock(wait_finish_mutex_); | |||
| wait_finish_cond_.wait(lock, [&] { | |||
| if (is_finish_.load()) { | |||
| MS_LOG(INFO) << "The scheduler finish success!"; | |||
| } | |||
| return is_finish_.load(); | |||
| }); | |||
| return true; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_ | |||
| #include <atomic> | |||
| #include <cstdlib> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <thread> | |||
| #include <mutex> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "ps/core/tcp_client.h" | |||
| #include "ps/core/tcp_server.h" | |||
| #include "ps/core/node_manager.h" | |||
| #include "ps/core/node.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class SchedulerNode : public Node { | |||
| public: | |||
| SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} | |||
| ~SchedulerNode() override; | |||
| bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| bool Stop() override; | |||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| private: | |||
| void Initialize(); | |||
| void CreateTcpServer(); | |||
| void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| void StartUpdateClusterStateTimer(); | |||
| void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| std::unique_ptr<TcpServer> server_; | |||
| std::unique_ptr<std::thread> scheduler_thread_; | |||
| std::unique_ptr<std::thread> update_state_thread_; | |||
| NodeManager node_manager_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_ | |||
| @@ -38,7 +38,7 @@ bool ServerNode::Start(const uint32_t &timeout) { | |||
| MS_LOG(INFO) << "Start server node!"; | |||
| Initialize(); | |||
| Register(client_to_scheduler_); | |||
| Heartbeat(client_to_scheduler_); | |||
| StartHeartbeatTimer(client_to_scheduler_); | |||
| if (!WaitForStart(timeout)) { | |||
| MS_LOG(EXCEPTION) << "Start Worker node timeout!"; | |||
| @@ -146,11 +146,7 @@ void TcpClient::StopEventBase() { | |||
| MS_LOG(INFO) << "Stop tcp client event base!"; | |||
| int ret = event_base_loopbreak(event_base_); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "Event base loop break failed!"; | |||
| } | |||
| if (event_base_) { | |||
| event_base_free(event_base_); | |||
| event_base_ = nullptr; | |||
| MS_LOG(ERROR) << "Event base loop break failed!"; | |||
| } | |||
| } | |||
| @@ -44,7 +44,7 @@ bool WorkerNode::Start(const uint32_t &timeout) { | |||
| MS_LOG(INFO) << "Starting worker node!"; | |||
| Initialize(); | |||
| Register(client_to_scheduler_); | |||
| Heartbeat(client_to_scheduler_); | |||
| StartHeartbeatTimer(client_to_scheduler_); | |||
| if (!WaitForStart(timeout)) { | |||
| MS_LOG(ERROR) << "Start Worker node timeout!"; | |||