| @@ -5,6 +5,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "util.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "embedding_table_shard_metadata.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/http_message_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/http_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/comm_util.cc") | |||
| @@ -16,6 +17,8 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") | |||
| endif () | |||
| if (NOT ENABLE_D) | |||
| @@ -0,0 +1,212 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/abstract_node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) { | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::REGISTER); | |||
| RegisterMessage register_message; | |||
| register_message.set_node_id(node_info_.node_id_); | |||
| register_message.set_role(node_info_.node_role_); | |||
| register_message.set_ip(node_info_.ip_); | |||
| register_message.set_port(node_info_.port_); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(register_message.SerializeAsString()); | |||
| if (!SendMessageSync(client, comm_message)) { | |||
| MS_LOG(EXCEPTION) << "Node register timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ | |||
| << "is registering to scheduler, the request id is:" << message_meta.request_id(); | |||
| } | |||
| void AbstractNode::ProcessRegisterResp(const CommMessage &message) { | |||
| RegisterRespMessage register_resp_message; | |||
| register_resp_message.ParseFromString(message.data()); | |||
| if (register_resp_message.node_id() != node_info_.node_id_) { | |||
| MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() | |||
| << " is not match the current node id:" << node_info_.node_id_; | |||
| } | |||
| node_info_.rank_id_ = register_resp_message.rank_id(); | |||
| MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; | |||
| } | |||
| bool AbstractNode::BroadcastToServers(const std::string &message, const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); | |||
| for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient((*it).first.second); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| return Wait(request_id, timeout); | |||
| } | |||
| void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_message) { | |||
| on_node_event_message_ = on_node_event_message; | |||
| } | |||
| void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) { | |||
| MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | |||
| << " begin send heartbeat to the scheduler!"; | |||
| heart_beat_thread_ = std::make_unique<std::thread>([&]() { | |||
| while (!is_finish_.load()) { | |||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::HEARTBEAT); | |||
| HeartbeatMessage heartbeat_message; | |||
| heartbeat_message.set_node_id(node_info_.node_id_); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(heartbeat_message.SerializeAsString()); | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | |||
| } | |||
| } | |||
| }); | |||
| heart_beat_thread_->detach(); | |||
| } | |||
| void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { | |||
| HeartbeatRespMessage heartbeat_resp_message; | |||
| heartbeat_resp_message.ParseFromString(message.data()); | |||
| is_ready_ = heartbeat_resp_message.is_cluster_ready(); | |||
| if (is_ready_.load()) { | |||
| wait_start_cond_.notify_all(); | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!"; | |||
| } | |||
| is_finish_ = heartbeat_resp_message.is_cluster_finish(); | |||
| if (is_finish_.load()) { | |||
| wait_finish_cond_.notify_all(); | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!"; | |||
| } | |||
| is_timeout_ = heartbeat_resp_message.is_cluster_timeout(); | |||
| if (is_timeout_ && on_node_event_message_) { | |||
| is_ready_ = true; | |||
| wait_start_cond_.notify_all(); | |||
| on_node_event_message_(NodeEvent::NODE_TIMEOUT); | |||
| } | |||
| } | |||
| void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::FETCH_SERVER); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(EXCEPTION) << "Fetch servers address timeout!"; | |||
| } | |||
| } | |||
| void AbstractNode::ProcessFetchServersResp(const CommMessage &message) { | |||
| FetchServersRespMessage fetch_servers_resp_message; | |||
| fetch_servers_resp_message.ParseFromString(message.data()); | |||
| for (const auto &it : fetch_servers_resp_message.servers_meta()) { | |||
| nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port()); | |||
| } | |||
| MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size(); | |||
| } | |||
| bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::FINISH); | |||
| FinishMessage finish_message; | |||
| finish_message.set_node_id(node_info_.node_id_); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(finish_message.SerializeAsString()); | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(EXCEPTION) << "Disconnect timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!"; | |||
| return WaitForDisconnect(timeout); | |||
| } | |||
| bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(wait_finish_mutex_); | |||
| bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| if (is_finish_.load()) { | |||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!"; | |||
| } | |||
| return is_finish_.load(); | |||
| }); | |||
| return res; | |||
| } | |||
| bool AbstractNode::InitClientToScheduler() { | |||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | |||
| uint16_t scheduler_port = ClusterConfig::scheduler_port(); | |||
| client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port); | |||
| client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| case NodeCommand::HEARTBEAT: | |||
| ProcessHeartbeatResp(message); | |||
| break; | |||
| case NodeCommand::REGISTER: | |||
| ProcessRegisterResp(message); | |||
| break; | |||
| case NodeCommand::FETCH_SERVER: | |||
| ProcessFetchServersResp(message); | |||
| break; | |||
| case NodeCommand::FINISH: | |||
| MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!"; | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| NotifyMessageArrival(message); | |||
| }); | |||
| client_to_scheduler_->Init(); | |||
| client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() { | |||
| MS_LOG(INFO) << "The worker node start a tcp client!"; | |||
| client_to_scheduler_->Start(); | |||
| }); | |||
| client_to_scheduler_thread_->detach(); | |||
| client_to_scheduler_->set_disconnected_callback([&]() { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval())); | |||
| client_to_scheduler_->Stop(); | |||
| client_to_scheduler_->Init(); | |||
| }); | |||
| return client_to_scheduler_->WaitConnected(); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ | |||
| #include <utility> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ps/core/node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class AbstractNode : public Node { | |||
| public: | |||
| AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} | |||
| ~AbstractNode() override = default; | |||
| bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| void set_event_callback(const OnNodeEventMessage &on_node_event_message); | |||
| protected: | |||
| void Register(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessRegisterResp(const CommMessage &message); | |||
| void Heartbeat(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessHeartbeatResp(const CommMessage &message); | |||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessFetchServersResp(const CommMessage &message); | |||
| bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); | |||
| bool WaitForDisconnect(const uint32_t &timeout); | |||
| bool InitClientToScheduler(); | |||
| std::unique_ptr<std::thread> heart_beat_thread_; | |||
| std::unique_ptr<std::thread> client_to_scheduler_thread_; | |||
| std::shared_ptr<TcpClient> client_to_scheduler_; | |||
| OnNodeEventMessage on_node_event_message_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ | |||
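For orientation, a concrete node built on AbstractNode is expected to drive start-up roughly as sketched below; this mirrors ServerNode::Start and WorkerNode::Start in this change, and the class name ExampleNode, its role choice, and the trivial Stop body are illustrative only, not part of the patch.

class ExampleNode : public AbstractNode {
 public:
  bool Start(const uint32_t &timeout = kTimeoutInSeconds) override {
    node_info_.node_id_ = CommUtil::GenerateUUID();  // as ServerNode/WorkerNode::Initialize() do
    node_info_.node_role_ = NodeRole::WORKER;        // illustrative role
    if (!InitClientToScheduler()) {                  // connect to the scheduler and start its event loop
      return false;
    }
    Register(client_to_scheduler_);                  // REGISTER: the scheduler assigns the rank id
    Heartbeat(client_to_scheduler_);                 // start the periodic HEARTBEAT thread
    if (!WaitForStart(timeout)) {                    // block until a heartbeat response reports the cluster ready
      return false;
    }
    if (!is_timeout_.load()) {
      FetchServers(client_to_scheduler_);            // FETCH_SERVER: cache <role, rank id> -> <ip, port>
    }
    return true;
  }
  bool Stop() override { return true; }              // trivial body, kept minimal for the sketch
  bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override {
    return Disconnect(client_to_scheduler_, timeout);
  }
};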
| @@ -31,6 +31,8 @@ uint32_t ClusterConfig::heartbeat_interval_ = 3; | |||
| uint32_t ClusterConfig::heartbeat_timeout_ = 30; | |||
| // Timeout period for cluster preparation is 300 seconds. | |||
| uint32_t ClusterConfig::cluster_available_timeout_ = 300; | |||
| // The retry interval for the client to reconnect to the server is 100 ms. | |||
| uint32_t ClusterConfig::connect_interval_ = 100; | |||
| void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, | |||
| std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) { | |||
| @@ -69,6 +71,9 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa | |||
| cluster_available_timeout_ = cluster_available_timeout; | |||
| } | |||
| uint32_t ClusterConfig::connect_interval() { return connect_interval_; } | |||
| void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -42,6 +42,8 @@ class ClusterConfig { | |||
| static void set_heartbeat_timeout(const uint32_t &heartbeat_timeout); | |||
| static uint32_t cluster_available_timeout(); | |||
| static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); | |||
| static uint32_t connect_interval(); | |||
| static void set_connect_interval(const uint32_t &connect_interval); | |||
| private: | |||
| static uint32_t worker_num_; | |||
| @@ -51,6 +53,7 @@ class ClusterConfig { | |||
| static uint16_t scheduler_port_; | |||
| static uint32_t heartbeat_timeout_; | |||
| static uint32_t cluster_available_timeout_; | |||
| static uint32_t connect_interval_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
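The new connect_interval is what the scheduler client's reconnect path sleeps on; a minimal usage sketch follows, with the 200 ms value purely illustrative.

// Illustrative: widen the reconnect interval before bringing the cluster up.
ClusterConfig::set_connect_interval(200);

// AbstractNode::InitClientToScheduler() then waits this long before each reconnect attempt:
client_to_scheduler_->set_disconnected_callback([&]() {
  std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval()));
  client_to_scheduler_->Stop();
  client_to_scheduler_->Init();
});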
| @@ -19,78 +19,11 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void Node::Heartbeat(const std::shared_ptr<TcpClient> &client) { | |||
| MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | |||
| << " begin send heartbeat to the scheduler!"; | |||
| heart_beat_thread_ = std::make_unique<std::thread>([&]() { | |||
| while (!is_finish_.load()) { | |||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::HEARTBEAT); | |||
| HeartbeatMessage heartbeat_message; | |||
| heartbeat_message.set_node_id(node_info_.node_id_); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(heartbeat_message.SerializeAsString()); | |||
| SendMessageAsync(client, message); | |||
| } | |||
| }); | |||
| heart_beat_thread_->detach(); | |||
| } | |||
| void Node::ProcessHeartbeatResp(const CommMessage &message) { | |||
| HeartbeatRespMessage heartbeat_resp_message; | |||
| heartbeat_resp_message.ParseFromString(message.data()); | |||
| is_ready_ = heartbeat_resp_message.is_cluster_ready(); | |||
| if (is_ready_.load()) { | |||
| wait_start_cond_.notify_all(); | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!"; | |||
| } | |||
| is_finish_ = heartbeat_resp_message.is_cluster_finish(); | |||
| if (is_finish_.load()) { | |||
| wait_finish_cond_.notify_all(); | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!"; | |||
| } | |||
| is_timeout_ = heartbeat_resp_message.is_cluster_timeout(); | |||
| if (is_timeout_ && on_node_event_message_) { | |||
| is_ready_ = true; | |||
| wait_start_cond_.notify_all(); | |||
| on_node_event_message_(NodeEvent::NODE_TIMEOUT); | |||
| } | |||
| } | |||
| void Node::FetchServers(const std::shared_ptr<TcpClient> &client) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::FETCH_SERVER); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(EXCEPTION) << "Fetch servers address timeout!"; | |||
| } | |||
| } | |||
| void Node::ProcessFetchServersResp(const CommMessage &message) { | |||
| FetchServersRespMessage fetch_servers_resp_message; | |||
| fetch_servers_resp_message.ParseFromString(message.data()); | |||
| for (const auto &it : fetch_servers_resp_message.servers_meta()) { | |||
| nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port()); | |||
| } | |||
| MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size(); | |||
| } | |||
| std::string Node::node_id() const { return node_info_.node_id_; } | |||
| uint32_t Node::rank_id() const { return node_info_.rank_id_; } | |||
| void Node::set_callback(const OnNodeEventMessage &on_node_event_message) { | |||
| on_node_event_message_ = on_node_event_message; | |||
| } | |||
| NodeRole Node::role() const { return node_info_.node_role_; } | |||
| bool Node::Wait(uint64_t request_id, const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(message_tracker_mutex_); | |||
| @@ -147,6 +80,7 @@ bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids | |||
| bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| CommMessage *comm_message_resp, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(comm_message_resp); | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| @@ -156,7 +90,7 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| comm_message_resp = &res[rank_id]; | |||
| *comm_message_resp = res[rank_id]; | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| @@ -164,6 +98,8 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| message_meta.set_rank_id(node_info_.rank_id_); | |||
| message_meta.set_role(node_info_.node_role_); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| @@ -175,6 +111,7 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s | |||
| bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||
| std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(comm_message_resp); | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||
| @@ -213,23 +150,6 @@ bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool Node::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::FINISH); | |||
| FinishMessage finish_message; | |||
| finish_message.set_node_id(node_info_.node_id_); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(finish_message.SerializeAsString()); | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(EXCEPTION) << "Disconnect timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!"; | |||
| return WaitForDisconnect(timeout); | |||
| } | |||
| bool Node::WaitForStart(const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(wait_start_mutex_); | |||
| bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| @@ -242,17 +162,6 @@ bool Node::WaitForStart(const uint32_t &timeout) { | |||
| return res; | |||
| } | |||
| bool Node::WaitForDisconnect(const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(wait_finish_mutex_); | |||
| bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| if (is_finish_.load()) { | |||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!"; | |||
| } | |||
| return is_finish_.load(); | |||
| }); | |||
| return res; | |||
| } | |||
| bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| @@ -268,15 +177,6 @@ void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const Comm | |||
| client->SendMessage(message); | |||
| } | |||
| void Node::NotifyMessageArrival(const CommMessage &message) { | |||
| std::lock_guard<std::mutex> lock(message_tracker_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| uint64_t request_id = message_meta.request_id(); | |||
| message_tracker_[request_id].second++; | |||
| message_tracker_cond_.notify_all(); | |||
| } | |||
| const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) { | |||
| std::lock_guard<std::mutex> lock(client_mutex_); | |||
| if (connected_nodes_.find(rank_id) != connected_nodes_.end()) { | |||
| @@ -292,6 +192,7 @@ const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) | |||
| switch (message.pb_meta().cmd()) { | |||
| case NodeCommand::SEND_DATA: | |||
| ProcessSendDataResp(message); | |||
| RunMessageCallback(message.pb_meta().request_id()); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| @@ -317,13 +218,13 @@ void Node::ProcessSendDataResp(const CommMessage &message) { | |||
| res.insert(std::make_pair(rank_id, message)); | |||
| receive_messages_[request_id] = res; | |||
| } | |||
| RunMessageCallback(request_id); | |||
| } | |||
| void Node::RunMessageCallback(const uint64_t &request_id) { | |||
| message_callbacks_mutex_.lock(); | |||
| if (message_tracker_[request_id].first == message_tracker_[request_id].second - 1) { | |||
| // When a response arrives, compare the expected number of responses with the number already counted. | |||
| // The current response has not been counted yet, so the callback fires on the last expected response. | |||
| if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) { | |||
| auto it = message_callbacks_.find(request_id); | |||
| if (it != message_callbacks_.end()) { | |||
| message_callbacks_mutex_.unlock(); | |||
| @@ -346,6 +247,15 @@ void Node::set_message_callback(const uint64_t &request_id, const MessageCallbac | |||
| std::lock_guard<std::mutex> lock(message_callbacks_mutex_); | |||
| message_callbacks_[request_id] = message_callback; | |||
| } | |||
| void Node::NotifyMessageArrival(const CommMessage &message) { | |||
| std::lock_guard<std::mutex> lock(message_tracker_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| uint64_t request_id = message_meta.request_id(); | |||
| message_tracker_[request_id].second++; | |||
| message_tracker_cond_.notify_all(); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
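The bookkeeping behind RunMessageCallback and NotifyMessageArrival, traced for a request fanned out to three servers (the counts are illustrative):

// message_tracker_[request_id] = {expected responses, responses already counted}
// after sending to 3 servers:   {3, 0}
// response #1: RunMessageCallback checks 3 == 0 + 1 -> no;  NotifyMessageArrival -> {3, 1}
// response #2: RunMessageCallback checks 3 == 1 + 1 -> no;  NotifyMessageArrival -> {3, 2}
// response #3: RunMessageCallback checks 3 == 2 + 1 -> yes -> user callback runs, then {3, 3} and Wait() returns
// The "+ 1" is needed because the callback runs before the current response has been counted.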
| @@ -52,8 +52,7 @@ class Node { | |||
| is_timeout_(false), | |||
| is_already_stopped_(true), | |||
| is_already_finished_(false), | |||
| next_request_id_(0), | |||
| heart_beat_thread_(nullptr) {} | |||
| next_request_id_(0) {} | |||
| virtual ~Node() = default; | |||
| using OnNodeEventMessage = std::function<void(const NodeEvent &event)>; | |||
| @@ -63,9 +62,10 @@ class Node { | |||
| virtual bool Stop() = 0; | |||
| virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; | |||
| void set_callback(const OnNodeEventMessage &on_node_event_message); | |||
| std::string node_id() const; | |||
| uint32_t rank_id() const; | |||
| NodeRole role() const; | |||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| @@ -73,27 +73,21 @@ class Node { | |||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| CommMessage *const comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| protected: | |||
| void Heartbeat(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessHeartbeatResp(const CommMessage &message); | |||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessFetchServersResp(const CommMessage &message); | |||
| bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); | |||
| bool WaitForStart(const uint32_t &timeout); | |||
| bool WaitForDisconnect(const uint32_t &timeout); | |||
| bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||
| void NotifyMessageArrival(const CommMessage &message); | |||
| const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); | |||
| void ProcessSendDataResp(const CommMessage &message); | |||
| void RunMessageCallback(const uint64_t &request_id); | |||
| void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); | |||
| void NotifyMessageArrival(const CommMessage &message); | |||
| NodeInfo node_info_; | |||
| std::atomic<bool> is_ready_; | |||
| @@ -102,9 +96,6 @@ class Node { | |||
| std::atomic<bool> is_already_stopped_; | |||
| std::atomic<bool> is_already_finished_; | |||
| std::atomic_uint64_t next_request_id_; | |||
| std::unique_ptr<std::thread> heart_beat_thread_; | |||
| OnNodeEventMessage on_node_event_message_; | |||
| // <NodeRole,rank_id>-><ip, port> | |||
| std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; | |||
| @@ -132,5 +123,4 @@ class Node { | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_NODE_H_ | |||
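A minimal sketch of the synchronous Send overload that hands back the peer's response; the rank id and payload are placeholders.

// From a started node: send SEND_DATA to server rank 0 and block until its reply is copied into resp.
CommMessage resp;
if (!Send(NodeRole::SERVER, 0, "payload", &resp)) {  // waits up to kCommTimeoutInSeconds
  MS_LOG(ERROR) << "Send to server 0 timed out!";
}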
| @@ -0,0 +1,145 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/server_node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| ServerNode::~ServerNode() { | |||
| MS_LOG(INFO) << "Stop server node!"; | |||
| if (!is_already_stopped_.load()) { | |||
| server_->Stop(); | |||
| client_to_scheduler_->Stop(); | |||
| client_to_scheduler_->StopEventBase(); | |||
| if (server_thread_->joinable()) { | |||
| server_thread_->join(); | |||
| } | |||
| if (client_to_scheduler_thread_->joinable()) { | |||
| client_to_scheduler_thread_->join(); | |||
| } | |||
| is_already_stopped_ = true; | |||
| } | |||
| } | |||
| bool ServerNode::Start(const uint32_t &timeout) { | |||
| MS_LOG(INFO) << "Start server node!"; | |||
| Initialize(); | |||
| Register(client_to_scheduler_); | |||
| Heartbeat(client_to_scheduler_); | |||
| if (!WaitForStart(timeout)) { | |||
| MS_LOG(EXCEPTION) << "Start Worker node timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The cluster is ready to use!"; | |||
| // If the cluster is ready to use, fetch the addresses of all the servers. | |||
| if (!is_timeout_.load()) { | |||
| FetchServers(client_to_scheduler_); | |||
| MS_LOG(INFO) << "Server node get all the servers address successful!"; | |||
| } | |||
| MS_LOG(INFO) << "Start the node is successful!"; | |||
| return true; | |||
| } | |||
| void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } | |||
| void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, | |||
| const std::string &message) { | |||
| auto &meta = const_cast<MessageMeta &>(message_meta); | |||
| meta.set_role(node_info_.node_role_); | |||
| meta.set_rank_id(node_info_.rank_id_); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {meta}; | |||
| comm_message.set_data(message); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||
| } | |||
| void ServerNode::CreateTcpServer() { | |||
| std::string interface; | |||
| std::string server_ip; | |||
| CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); | |||
| server_ = std::make_shared<TcpServer>(server_ip, 0); | |||
| server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| case NodeCommand::SEND_DATA: | |||
| ProcessSendData(server, conn, message); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| }); | |||
| server_->Init(); | |||
| server_thread_ = std::make_unique<std::thread>([&]() { | |||
| MS_LOG(INFO) << "The server node start a tcp server!"; | |||
| server_->Start(); | |||
| }); | |||
| server_thread_->detach(); | |||
| } | |||
| void ServerNode::Initialize() { | |||
| CreateTcpServer(); | |||
| is_already_stopped_ = false; | |||
| node_info_.node_id_ = CommUtil::GenerateUUID(); | |||
| node_info_.node_role_ = NodeRole::SERVER; | |||
| node_info_.ip_ = server_->BoundIp(); | |||
| node_info_.port_ = server_->BoundPort(); | |||
| MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " is generate uuid is:" << node_info_.node_id_; | |||
| if (!InitClientToScheduler()) { | |||
| MS_LOG(EXCEPTION) << "Server node init client timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "Server node init client successful!"; | |||
| } | |||
| void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| if (request_handler_) { | |||
| request_handler_(server, conn, message.pb_meta(), message.data()); | |||
| } | |||
| } | |||
| bool ServerNode::Stop() { | |||
| MS_LOG(INFO) << "Stop server node!"; | |||
| if (!is_already_stopped_.load()) { | |||
| server_->Stop(); | |||
| client_to_scheduler_->Stop(); | |||
| client_to_scheduler_->StopEventBase(); | |||
| if (server_thread_->joinable()) { | |||
| server_thread_->join(); | |||
| } | |||
| if (client_to_scheduler_thread_->joinable()) { | |||
| client_to_scheduler_thread_->join(); | |||
| } | |||
| if (heart_beat_thread_->joinable()) { | |||
| heart_beat_thread_->join(); | |||
| } | |||
| is_already_stopped_ = true; | |||
| } | |||
| return true; | |||
| } | |||
| bool ServerNode::Finish(const uint32_t &timeout) { | |||
| std::lock_guard<std::mutex> lock(finish_mutex_); | |||
| if (is_already_finished_) { | |||
| MS_LOG(INFO) << "Server node already finish!"; | |||
| return true; | |||
| } | |||
| is_already_finished_ = true; | |||
| return Disconnect(client_to_scheduler_, timeout); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_ | |||
| #include <cstdlib> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <utility> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "ps/core/tcp_client.h" | |||
| #include "ps/core/tcp_server.h" | |||
| #include "ps/core/abstract_node.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class ServerNode : public AbstractNode { | |||
| public: | |||
| ServerNode() : server_(nullptr), server_thread_(nullptr) {} | |||
| ~ServerNode() override; | |||
| bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| bool Stop() override; | |||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, | |||
| const MessageMeta &message_meta, const std::string &message)>; | |||
| void set_handler(const RequestHandler &handler); | |||
| void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, | |||
| const std::string &message); | |||
| private: | |||
| void CreateTcpServer(); | |||
| void Initialize(); | |||
| void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| std::shared_ptr<TcpServer> server_; | |||
| std::unique_ptr<std::thread> server_thread_; | |||
| RequestHandler request_handler_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_ | |||
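A minimal sketch of wiring a ServerNode request handler; the echo body is illustrative, while set_handler, ProcessSendData, and Response are the plumbing this change adds.

ServerNode server_node;
server_node.set_handler([&](const TcpServer &server, const TcpConnection &conn,
                            const MessageMeta &message_meta, const std::string &message) {
  // Every SEND_DATA payload reaches here through ProcessSendData(); echo it back to the sender.
  server_node.Response(server, conn, message_meta, message);
});
server_node.Start();  // registers with the scheduler, starts heartbeats, then serves requests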
| @@ -35,7 +35,6 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| event_base *TcpClient::event_base_ = nullptr; | |||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port) | |||
| @@ -43,7 +42,8 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port) | |||
| buffer_event_(nullptr), | |||
| server_address_(std::move(address)), | |||
| server_port_(port), | |||
| is_stop_(true) { | |||
| is_stop_(true), | |||
| is_connected_(false) { | |||
| message_handler_.SetCallback([this](const CommMessage &message) { | |||
| if (message_callback_) { | |||
| message_callback_(*this, message); | |||
| @@ -55,12 +55,15 @@ TcpClient::~TcpClient() { Stop(); } | |||
| std::string TcpClient::GetServerAddress() const { return server_address_; } | |||
| void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, | |||
| const OnTimeout &timeout) { | |||
| connected_callback_ = conn; | |||
| disconnected_callback_ = disconn; | |||
| read_callback_ = read; | |||
| timeout_callback_ = timeout; | |||
| void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) { disconnected_callback_ = disconnected; } | |||
| void TcpClient::set_connected_callback(const OnConnected &connected) { connected_callback_ = connected; } | |||
| bool TcpClient::WaitConnected(const uint32_t &connected_timeout) { | |||
| std::unique_lock<std::mutex> lock(connection_mutex_); | |||
| bool res = | |||
| connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), [&] { return is_connected_.load(); }); | |||
| return res; | |||
| } | |||
| void TcpClient::Init() { | |||
| @@ -68,6 +71,7 @@ void TcpClient::Init() { | |||
| if (buffer_event_) { | |||
| return; | |||
| } | |||
| is_stop_ = false; | |||
| if (!CommUtil::CheckIp(server_address_)) { | |||
| MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; | |||
| } | |||
| @@ -198,6 +202,12 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) { | |||
| } | |||
| } | |||
| void TcpClient::NotifyConnected() { | |||
| MS_LOG(INFO) << "Client connected to the server!"; | |||
| is_connected_ = true; | |||
| connection_cond_.notify_all(); | |||
| } | |||
| void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| @@ -205,27 +215,24 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||
| if (events & BEV_EVENT_CONNECTED) { | |||
| // Connected | |||
| if (tcp_client->connected_callback_) { | |||
| tcp_client->connected_callback_(*tcp_client); | |||
| tcp_client->connected_callback_(); | |||
| } | |||
| evutil_socket_t fd = bufferevent_getfd(const_cast<struct bufferevent *>(bev)); | |||
| tcp_client->NotifyConnected(); | |||
| evutil_socket_t fd = bufferevent_getfd(bev); | |||
| SetTcpNoDelay(fd); | |||
| MS_LOG(INFO) << "Client connected!"; | |||
| } else if (events & BEV_EVENT_ERROR) { | |||
| MS_LOG(ERROR) << "Client connected error!"; | |||
| if (tcp_client->disconnected_callback_) { | |||
| tcp_client->disconnected_callback_(*tcp_client, errno); | |||
| tcp_client->disconnected_callback_(); | |||
| } | |||
| } else if (events & BEV_EVENT_EOF) { | |||
| MS_LOG(ERROR) << "Client connected end of file"; | |||
| if (tcp_client->disconnected_callback_) { | |||
| tcp_client->disconnected_callback_(*tcp_client, 0); | |||
| } | |||
| } | |||
| } | |||
| void TcpClient::Start() { | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| is_stop_ = false; | |||
| int ret = event_base_dispatch(event_base_); | |||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | |||
| @@ -30,19 +30,19 @@ | |||
| #include <thread> | |||
| #include <mutex> | |||
| #include <atomic> | |||
| #include <condition_variable> | |||
| #include "ps/core/cluster_config.h" | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class TcpClient { | |||
| public: | |||
| using OnConnected = std::function<void(const TcpClient &)>; | |||
| using OnDisconnected = std::function<void(const TcpClient &, int)>; | |||
| using OnConnected = std::function<void()>; | |||
| using OnDisconnected = std::function<void()>; | |||
| using OnRead = std::function<void(const TcpClient &, const void *, size_t)>; | |||
| using OnTimeout = std::function<void(const TcpClient &)>; | |||
| using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>; | |||
| @@ -52,8 +52,9 @@ class TcpClient { | |||
| virtual ~TcpClient(); | |||
| std::string GetServerAddress() const; | |||
| void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, | |||
| const OnTimeout &timeout); | |||
| void set_disconnected_callback(const OnDisconnected &disconnected); | |||
| void set_connected_callback(const OnConnected &connected); | |||
| bool WaitConnected(const uint32_t &connected_timeout = ClusterConfig::cluster_available_timeout()); | |||
| void Init(); | |||
| void StartWithDelay(int seconds); | |||
| void Stop(); | |||
| @@ -73,6 +74,7 @@ class TcpClient { | |||
| static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); | |||
| virtual void OnReadHandler(const void *buf, size_t num); | |||
| static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); | |||
| void NotifyConnected(); | |||
| private: | |||
| OnMessage message_callback_; | |||
| @@ -86,12 +88,14 @@ class TcpClient { | |||
| static event_base *event_base_; | |||
| std::mutex connection_mutex_; | |||
| std::condition_variable connection_cond_; | |||
| event *event_timeout_; | |||
| bufferevent *buffer_event_; | |||
| std::string server_address_; | |||
| std::uint16_t server_port_; | |||
| std::atomic<bool> is_stop_; | |||
| std::atomic<bool> is_connected_; | |||
| }; | |||
| } // namespace core | |||
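The connection handshake the scheduler client now relies on, sketched in isolation; scheduler_host and scheduler_port are placeholders.

auto client = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
client->set_connected_callback([]() { MS_LOG(INFO) << "Connected to the scheduler!"; });
client->Init();
std::thread([client]() { client->Start(); }).detach();  // Start() blocks in event_base_dispatch()
if (!client->WaitConnected()) {                          // default timeout: cluster_available_timeout() seconds
  MS_LOG(EXCEPTION) << "Connecting to the scheduler timed out!";
}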
| @@ -95,6 +95,7 @@ void TcpServer::Init() { | |||
| MS_LOG(EXCEPTION) << "Use event pthread failed!"; | |||
| } | |||
| is_stop_ = false; | |||
| base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| if (!CommUtil::CheckIp(server_address_)) { | |||
| @@ -138,7 +139,6 @@ void TcpServer::Start() { | |||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||
| MS_LOG(INFO) << "Start tcp server!"; | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| is_stop_ = false; | |||
| int ret = event_base_dispatch(base_); | |||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | |||
| @@ -368,6 +368,7 @@ int TcpServer::ConnectionNum() const { return connections_.size(); } | |||
| const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } | |||
| void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -31,8 +31,8 @@ WorkerNode::~WorkerNode() { | |||
| } | |||
| } | |||
| client_to_scheduler_->StopEventBase(); | |||
| if (worker_thread_->joinable()) { | |||
| worker_thread_->join(); | |||
| if (client_to_scheduler_thread_->joinable()) { | |||
| client_to_scheduler_thread_->join(); | |||
| } | |||
| if (heart_beat_thread_->joinable()) { | |||
| heart_beat_thread_->join(); | |||
| @@ -43,7 +43,7 @@ WorkerNode::~WorkerNode() { | |||
| bool WorkerNode::Start(const uint32_t &timeout) { | |||
| MS_LOG(INFO) << "Starting worker node!"; | |||
| Initialize(); | |||
| Register(); | |||
| Register(client_to_scheduler_); | |||
| Heartbeat(client_to_scheduler_); | |||
| if (!WaitForStart(timeout)) { | |||
| @@ -52,84 +52,25 @@ bool WorkerNode::Start(const uint32_t &timeout) { | |||
| } | |||
| MS_LOG(INFO) << "The node is ready to fetch servers!"; | |||
| // If the cluster is ready to use, fetch the addresses of all the servers. | |||
| if (!is_timeout_.load()) { | |||
| FetchServers(client_to_scheduler_); | |||
| MS_LOG(INFO) << "Fetch servers successful!"; | |||
| MS_LOG(INFO) << "Worker node get all the servers address successful!"; | |||
| } | |||
| MS_LOG(INFO) << "The Worker node has successfully started."; | |||
| return true; | |||
| } | |||
| void WorkerNode::Register() { | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::REGISTER); | |||
| RegisterMessage register_message; | |||
| register_message.set_node_id(node_info_.node_id_); | |||
| register_message.set_role(node_info_.node_role_); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(register_message.SerializeAsString()); | |||
| if (!SendMessageSync(client_to_scheduler_, comm_message)) { | |||
| MS_LOG(EXCEPTION) << "Worker node register timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The worker node id:" << node_info_.node_id_ | |||
| << "is registering to scheduler, the request id is:" << message_meta.request_id(); | |||
| } | |||
| void WorkerNode::ProcessRegisterResp(const CommMessage &message) { | |||
| RegisterRespMessage register_resp_message; | |||
| register_resp_message.ParseFromString(message.data()); | |||
| if (register_resp_message.node_id() != node_info_.node_id_) { | |||
| MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() | |||
| << " is not match the current node id:" << node_info_.node_id_; | |||
| } | |||
| node_info_.rank_id_ = register_resp_message.rank_id(); | |||
| MS_LOG(INFO) << "The client node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; | |||
| } | |||
| void WorkerNode::Initialize() { | |||
| is_already_stopped_ = false; | |||
| node_info_.node_id_ = CommUtil::GenerateUUID(); | |||
| node_info_.node_role_ = NodeRole::WORKER; | |||
| MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_; | |||
| InitClientToScheduler(); | |||
| } | |||
| void WorkerNode::InitClientToScheduler() { | |||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | |||
| uint16_t scheduler_port = ClusterConfig::scheduler_port(); | |||
| client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port); | |||
| client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| case NodeCommand::HEARTBEAT: | |||
| ProcessHeartbeatResp(message); | |||
| break; | |||
| case NodeCommand::REGISTER: | |||
| ProcessRegisterResp(message); | |||
| break; | |||
| case NodeCommand::FETCH_SERVER: | |||
| ProcessFetchServersResp(message); | |||
| break; | |||
| case NodeCommand::FINISH: | |||
| MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!"; | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| NotifyMessageArrival(message); | |||
| }); | |||
| client_to_scheduler_->Init(); | |||
| worker_thread_ = std::make_unique<std::thread>([&]() { | |||
| MS_LOG(INFO) << "The worker node start a tcp client!"; | |||
| client_to_scheduler_->Start(); | |||
| }); | |||
| worker_thread_->detach(); | |||
| if (!InitClientToScheduler()) { | |||
| MS_LOG(EXCEPTION) << "Worker node init client timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "Worker node init client successful!"; | |||
| } | |||
| bool WorkerNode::Stop() { | |||
| @@ -144,8 +85,8 @@ bool WorkerNode::Stop() { | |||
| } | |||
| } | |||
| client_to_scheduler_->StopEventBase(); | |||
| if (worker_thread_->joinable()) { | |||
| worker_thread_->join(); | |||
| if (client_to_scheduler_thread_->joinable()) { | |||
| client_to_scheduler_thread_->join(); | |||
| } | |||
| if (heart_beat_thread_->joinable()) { | |||
| heart_beat_thread_->join(); | |||
| @@ -165,23 +106,6 @@ bool WorkerNode::Finish(const uint32_t &timeout) { | |||
| is_already_finished_ = true; | |||
| return Disconnect(client_to_scheduler_, timeout); | |||
| } | |||
| bool WorkerNode::BroadcastToServers(const std::string &message) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); | |||
| for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient((*it).first.second); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| return Wait(request_id); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -17,50 +17,35 @@ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | |||
| #include <atomic> | |||
| #include <cstdlib> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <thread> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <condition_variable> | |||
| #include <algorithm> | |||
| #include <tuple> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "ps/core/tcp_client.h" | |||
| #include "ps/core/tcp_server.h" | |||
| #include "ps/core/node.h" | |||
| #include "ps/core/abstract_node.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class WorkerNode : public Node { | |||
| class WorkerNode : public AbstractNode { | |||
| public: | |||
| WorkerNode() : client_to_scheduler_(nullptr), worker_thread_(nullptr) {} | |||
| WorkerNode() = default; | |||
| ~WorkerNode() override; | |||
| bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| bool Stop() override; | |||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| bool BroadcastToServers(const std::string &message); | |||
| private: | |||
| void Register(); | |||
| void ProcessRegisterResp(const CommMessage &message); | |||
| void Initialize(); | |||
| void InitClientToScheduler(); | |||
| std::shared_ptr<TcpClient> client_to_scheduler_; | |||
| std::unique_ptr<std::thread> worker_thread_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
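With BroadcastToServers hoisted into AbstractNode, a worker fans a payload out to every known server as below; the payload is a placeholder.

WorkerNode worker;
worker.Start();                                // registers, heartbeats, and fetches the server addresses
if (!worker.BroadcastToServers("payload")) {   // waits up to kCommTimeoutInSeconds for every server's response
  MS_LOG(ERROR) << "Broadcast to the servers timed out!";
}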
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/embedding_table_shard_metadata.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| uint64_t EmbeddingTableShardMetadata::begin() const { return begin_; } | |||
| uint64_t EmbeddingTableShardMetadata::end() const { return end_; } | |||
| uint64_t EmbeddingTableShardMetadata::size() const { return end_ - begin_; } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_ | |||
| #define MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_ | |||
| #include <iostream> | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| class EmbeddingTableShardMetadata { | |||
| public: | |||
| explicit EmbeddingTableShardMetadata(uint64_t begin, uint64_t end) : begin_(begin), end_(end) {} | |||
| virtual ~EmbeddingTableShardMetadata() = default; | |||
| uint64_t begin() const; | |||
| uint64_t end() const; | |||
| uint64_t size() const; | |||
| private: | |||
| uint64_t begin_; | |||
| uint64_t end_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_ | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common_test.h" | |||
| #include "ps/embedding_table_shard_metadata.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| class TestEmbeddingTableShardMetadata : public UT::Common { | |||
| public: | |||
| TestEmbeddingTableShardMetadata() = default; | |||
| virtual ~TestEmbeddingTableShardMetadata() = default; | |||
| void SetUp() override {} | |||
| void TearDown() override {} | |||
| }; | |||
| TEST_F(TestEmbeddingTableShardMetadata, EmbeddingTable) { | |||
| EmbeddingTableShardMetadata embedding_table_shard(1, 100); | |||
| EXPECT_EQ(embedding_table_shard.begin(), 1); | |||
| EXPECT_EQ(embedding_table_shard.end(), 100); | |||
| EXPECT_EQ(embedding_table_shard.size(), 99); | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||