From 2d2bf2d0eefe76c4b920acd6591172e8514bb733 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Fri, 11 Dec 2020 17:14:32 +0800 Subject: [PATCH] added server node --- mindspore/ccsrc/ps/CMakeLists.txt | 3 + mindspore/ccsrc/ps/core/abstract_node.cc | 212 ++++++++++++++++++ mindspore/ccsrc/ps/core/abstract_node.h | 56 +++++ mindspore/ccsrc/ps/core/cluster_config.cc | 5 + mindspore/ccsrc/ps/core/cluster_config.h | 3 + mindspore/ccsrc/ps/core/node.cc | 128 ++--------- mindspore/ccsrc/ps/core/node.h | 20 +- mindspore/ccsrc/ps/core/server_node.cc | 145 ++++++++++++ mindspore/ccsrc/ps/core/server_node.h | 67 ++++++ mindspore/ccsrc/ps/core/tcp_client.cc | 37 +-- mindspore/ccsrc/ps/core/tcp_client.h | 16 +- mindspore/ccsrc/ps/core/tcp_server.cc | 3 +- mindspore/ccsrc/ps/core/worker_node.cc | 98 +------- mindspore/ccsrc/ps/core/worker_node.h | 21 +- .../ps/embedding_table_shard_metadata.cc | 27 +++ .../ccsrc/ps/embedding_table_shard_metadata.h | 40 ++++ .../ps/embedding_table_shard_metadata_test.cc | 38 ++++ 17 files changed, 668 insertions(+), 251 deletions(-) create mode 100644 mindspore/ccsrc/ps/core/abstract_node.cc create mode 100644 mindspore/ccsrc/ps/core/abstract_node.h create mode 100644 mindspore/ccsrc/ps/core/server_node.cc create mode 100644 mindspore/ccsrc/ps/core/server_node.h create mode 100644 mindspore/ccsrc/ps/embedding_table_shard_metadata.cc create mode 100644 mindspore/ccsrc/ps/embedding_table_shard_metadata.h create mode 100644 tests/ut/cpp/ps/embedding_table_shard_metadata_test.cc diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index 318e99635d..6d05481207 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -5,6 +5,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") list(REMOVE_ITEM _PS_SRC_FILES "util.cc") + list(REMOVE_ITEM _PS_SRC_FILES "embedding_table_shard_metadata.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/http_message_handler.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/http_server.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/comm_util.cc") @@ -16,6 +17,8 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") endif () if (NOT ENABLE_D) diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc new file mode 100644 index 0000000000..b1eaf9b4d0 --- /dev/null +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/core/abstract_node.h" + +namespace mindspore { +namespace ps { +namespace core { +void AbstractNode::Register(const std::shared_ptr &client) { + MessageMeta message_meta; + message_meta.set_cmd(NodeCommand::REGISTER); + + RegisterMessage register_message; + register_message.set_node_id(node_info_.node_id_); + register_message.set_role(node_info_.node_role_); + register_message.set_ip(node_info_.ip_); + register_message.set_port(node_info_.port_); + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message_meta}; + comm_message.set_data(register_message.SerializeAsString()); + if (!SendMessageSync(client, comm_message)) { + MS_LOG(EXCEPTION) << "Node register timeout!"; + } + + MS_LOG(INFO) << "The node id:" << node_info_.node_id_ + << "is registering to scheduler, the request id is:" << message_meta.request_id(); +} + +void AbstractNode::ProcessRegisterResp(const CommMessage &message) { + RegisterRespMessage register_resp_message; + register_resp_message.ParseFromString(message.data()); + if (register_resp_message.node_id() != node_info_.node_id_) { + MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() + << " is not match the current node id:" << node_info_.node_id_; + } + + node_info_.rank_id_ = register_resp_message.rank_id(); + + MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; +} + +bool AbstractNode::BroadcastToServers(const std::string &message, const uint32_t &timeout) { + uint64_t request_id = ++next_request_id_; + message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); + for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { + MessageMeta message_meta; + message_meta.set_cmd(NodeCommand::SEND_DATA); + message_meta.set_request_id(request_id); + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message_meta}; + comm_message.set_data(message); + auto client = GetOrCreateTcpClient((*it).first.second); + client->SendMessage(comm_message); + } + return Wait(request_id, timeout); +} + +void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_message) { + on_node_event_message_ = on_node_event_message; +} + +void AbstractNode::Heartbeat(const std::shared_ptr &client) { + MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ + << " begin send heartbeat to the scheduler!"; + heart_beat_thread_ = std::make_unique([&]() { + while (!is_finish_.load()) { + std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); + MessageMeta meta; + meta.set_cmd(NodeCommand::HEARTBEAT); + + HeartbeatMessage heartbeat_message; + heartbeat_message.set_node_id(node_info_.node_id_); + + CommMessage message; + *message.mutable_pb_meta() = {meta}; + message.set_data(heartbeat_message.SerializeAsString()); + if (!SendMessageSync(client, message)) { + MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; + } + } + }); + heart_beat_thread_->detach(); +} + +void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { + HeartbeatRespMessage heartbeat_resp_message; + heartbeat_resp_message.ParseFromString(message.data()); + is_ready_ = heartbeat_resp_message.is_cluster_ready(); + if (is_ready_.load()) { + wait_start_cond_.notify_all(); + MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!"; + } + is_finish_ = heartbeat_resp_message.is_cluster_finish(); + if (is_finish_.load()) { + wait_finish_cond_.notify_all(); + MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!"; + } + is_timeout_ = heartbeat_resp_message.is_cluster_timeout(); + if (is_timeout_ && on_node_event_message_) { + is_ready_ = true; + wait_start_cond_.notify_all(); + on_node_event_message_(NodeEvent::NODE_TIMEOUT); + } +} + +void AbstractNode::FetchServers(const std::shared_ptr &client) { + MessageMeta meta; + meta.set_cmd(NodeCommand::FETCH_SERVER); + + CommMessage message; + *message.mutable_pb_meta() = {meta}; + if (!SendMessageSync(client, message)) { + MS_LOG(EXCEPTION) << "Fetch servers address timeout!"; + } +} + +void AbstractNode::ProcessFetchServersResp(const CommMessage &message) { + FetchServersRespMessage fetch_servers_resp_message; + fetch_servers_resp_message.ParseFromString(message.data()); + + for (const auto &it : fetch_servers_resp_message.servers_meta()) { + nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port()); + } + + MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size(); +} + +bool AbstractNode::Disconnect(const std::shared_ptr &client, const uint32_t &timeout) { + MessageMeta meta; + meta.set_cmd(NodeCommand::FINISH); + + FinishMessage finish_message; + finish_message.set_node_id(node_info_.node_id_); + + CommMessage message; + *message.mutable_pb_meta() = {meta}; + message.set_data(finish_message.SerializeAsString()); + if (!SendMessageSync(client, message)) { + MS_LOG(EXCEPTION) << "Disconnect timeout!"; + } + MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!"; + return WaitForDisconnect(timeout); +} + +bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) { + std::unique_lock lock(wait_finish_mutex_); + bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { + if (is_finish_.load()) { + MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!"; + } + return is_finish_.load(); + }); + return res; +} + +bool AbstractNode::InitClientToScheduler() { + std::string scheduler_host = ClusterConfig::scheduler_host(); + uint16_t scheduler_port = ClusterConfig::scheduler_port(); + client_to_scheduler_ = std::make_shared(scheduler_host, scheduler_port); + client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { + switch (message.pb_meta().cmd()) { + case NodeCommand::HEARTBEAT: + ProcessHeartbeatResp(message); + break; + case NodeCommand::REGISTER: + ProcessRegisterResp(message); + break; + case NodeCommand::FETCH_SERVER: + ProcessFetchServersResp(message); + break; + case NodeCommand::FINISH: + MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!"; + break; + default: + MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; + } + NotifyMessageArrival(message); + }); + + client_to_scheduler_->Init(); + client_to_scheduler_thread_ = std::make_unique([&]() { + MS_LOG(INFO) << "The worker node start a tcp client!"; + client_to_scheduler_->Start(); + }); + client_to_scheduler_thread_->detach(); + + client_to_scheduler_->set_disconnected_callback([&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval())); + client_to_scheduler_->Stop(); + client_to_scheduler_->Init(); + }); + return client_to_scheduler_->WaitConnected(); +} +} // namespace core +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h new file mode 100644 index 0000000000..e1fe6a3d7f --- /dev/null +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ +#define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ + +#include +#include +#include + +#include "ps/core/node.h" + +namespace mindspore { +namespace ps { +namespace core { +class AbstractNode : public Node { + public: + AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} + ~AbstractNode() override = default; + + bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds); + void set_event_callback(const OnNodeEventMessage &on_node_event_message); + + protected: + void Register(const std::shared_ptr &client); + void ProcessRegisterResp(const CommMessage &message); + void Heartbeat(const std::shared_ptr &client); + void ProcessHeartbeatResp(const CommMessage &message); + void FetchServers(const std::shared_ptr &client); + void ProcessFetchServersResp(const CommMessage &message); + bool Disconnect(const std::shared_ptr &client, const uint32_t &timeout); + bool WaitForDisconnect(const uint32_t &timeout); + bool InitClientToScheduler(); + + std::unique_ptr heart_beat_thread_; + std::unique_ptr client_to_scheduler_thread_; + std::shared_ptr client_to_scheduler_; + OnNodeEventMessage on_node_event_message_; +}; +} // namespace core +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ diff --git a/mindspore/ccsrc/ps/core/cluster_config.cc b/mindspore/ccsrc/ps/core/cluster_config.cc index 2fb9052cd9..23f1635da7 100644 --- a/mindspore/ccsrc/ps/core/cluster_config.cc +++ b/mindspore/ccsrc/ps/core/cluster_config.cc @@ -31,6 +31,8 @@ uint32_t ClusterConfig::heartbeat_interval_ = 3; uint32_t ClusterConfig::heartbeat_timeout_ = 30; // Timeout period for cluster preparation is 300 seconds. uint32_t ClusterConfig::cluster_available_timeout_ = 300; +// The timeout period for the client to connect to the server is 100ms. +uint32_t ClusterConfig::connect_interval_ = 100; void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr scheduler_host, const uint16_t &scheduler_port) { @@ -69,6 +71,9 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa cluster_available_timeout_ = cluster_available_timeout; } +uint32_t ClusterConfig::connect_interval() { return connect_interval_; } + +void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; } } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/cluster_config.h b/mindspore/ccsrc/ps/core/cluster_config.h index 7cf379c684..20104949a7 100644 --- a/mindspore/ccsrc/ps/core/cluster_config.h +++ b/mindspore/ccsrc/ps/core/cluster_config.h @@ -42,6 +42,8 @@ class ClusterConfig { static void set_heartbeat_timeout(const uint32_t &heartbeat_timeout); static uint32_t cluster_available_timeout(); static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); + static uint32_t connect_interval(); + static void set_connect_interval(const uint32_t &connect_interval); private: static uint32_t worker_num_; @@ -51,6 +53,7 @@ class ClusterConfig { static uint16_t scheduler_port_; static uint32_t heartbeat_timeout_; static uint32_t cluster_available_timeout_; + static uint32_t connect_interval_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/node.cc b/mindspore/ccsrc/ps/core/node.cc index 500afb11d2..2ae02f6a39 100644 --- a/mindspore/ccsrc/ps/core/node.cc +++ b/mindspore/ccsrc/ps/core/node.cc @@ -19,78 +19,11 @@ namespace mindspore { namespace ps { namespace core { -void Node::Heartbeat(const std::shared_ptr &client) { - MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) - << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ - << " begin send heartbeat to the scheduler!"; - heart_beat_thread_ = std::make_unique([&]() { - while (!is_finish_.load()) { - std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); - MessageMeta meta; - meta.set_cmd(NodeCommand::HEARTBEAT); - - HeartbeatMessage heartbeat_message; - heartbeat_message.set_node_id(node_info_.node_id_); - - CommMessage message; - *message.mutable_pb_meta() = {meta}; - message.set_data(heartbeat_message.SerializeAsString()); - SendMessageAsync(client, message); - } - }); - heart_beat_thread_->detach(); -} - -void Node::ProcessHeartbeatResp(const CommMessage &message) { - HeartbeatRespMessage heartbeat_resp_message; - heartbeat_resp_message.ParseFromString(message.data()); - is_ready_ = heartbeat_resp_message.is_cluster_ready(); - if (is_ready_.load()) { - wait_start_cond_.notify_all(); - MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!"; - } - is_finish_ = heartbeat_resp_message.is_cluster_finish(); - if (is_finish_.load()) { - wait_finish_cond_.notify_all(); - MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!"; - } - is_timeout_ = heartbeat_resp_message.is_cluster_timeout(); - if (is_timeout_ && on_node_event_message_) { - is_ready_ = true; - wait_start_cond_.notify_all(); - on_node_event_message_(NodeEvent::NODE_TIMEOUT); - } -} - -void Node::FetchServers(const std::shared_ptr &client) { - MessageMeta meta; - meta.set_cmd(NodeCommand::FETCH_SERVER); - - CommMessage message; - *message.mutable_pb_meta() = {meta}; - if (!SendMessageSync(client, message)) { - MS_LOG(EXCEPTION) << "Fetch servers address timeout!"; - } -} - -void Node::ProcessFetchServersResp(const CommMessage &message) { - FetchServersRespMessage fetch_servers_resp_message; - fetch_servers_resp_message.ParseFromString(message.data()); - - for (const auto &it : fetch_servers_resp_message.servers_meta()) { - nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port()); - } - - MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size(); -} - std::string Node::node_id() const { return node_info_.node_id_; } uint32_t Node::rank_id() const { return node_info_.rank_id_; } -void Node::set_callback(const OnNodeEventMessage &on_node_event_message) { - on_node_event_message_ = on_node_event_message; -} +NodeRole Node::role() const { return node_info_.node_role_; } bool Node::Wait(uint64_t request_id, const uint32_t &timeout) { std::unique_lock lock(message_tracker_mutex_); @@ -147,6 +80,7 @@ bool Node::Send(const NodeRole &node_role, const std::vector &rank_ids bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *comm_message_resp, const uint32_t &timeout) { + MS_EXCEPTION_IF_NULL(comm_message_resp); if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } @@ -156,7 +90,7 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); auto res = receive_messages_[request_id]; - comm_message_resp = &res[rank_id]; + *comm_message_resp = res[rank_id]; receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); }); @@ -164,6 +98,8 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s MessageMeta message_meta; message_meta.set_cmd(NodeCommand::SEND_DATA); message_meta.set_request_id(request_id); + message_meta.set_rank_id(node_info_.rank_id_); + message_meta.set_role(node_info_.node_role_); CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; @@ -175,6 +111,7 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s bool Node::Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, std::vector *comm_message_resp, const uint32_t &timeout) { + MS_EXCEPTION_IF_NULL(comm_message_resp); uint64_t request_id = ++next_request_id_; message_tracker_[request_id] = std::make_pair(data.size(), 0); @@ -213,23 +150,6 @@ bool Node::Send(const NodeRole &node_role, const std::vector &rank_ids return Wait(request_id, timeout); } -bool Node::Disconnect(const std::shared_ptr &client, const uint32_t &timeout) { - MessageMeta meta; - meta.set_cmd(NodeCommand::FINISH); - - FinishMessage finish_message; - finish_message.set_node_id(node_info_.node_id_); - - CommMessage message; - *message.mutable_pb_meta() = {meta}; - message.set_data(finish_message.SerializeAsString()); - if (!SendMessageSync(client, message)) { - MS_LOG(EXCEPTION) << "Disconnect timeout!"; - } - MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!"; - return WaitForDisconnect(timeout); -} - bool Node::WaitForStart(const uint32_t &timeout) { std::unique_lock lock(wait_start_mutex_); bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { @@ -242,17 +162,6 @@ bool Node::WaitForStart(const uint32_t &timeout) { return res; } -bool Node::WaitForDisconnect(const uint32_t &timeout) { - std::unique_lock lock(wait_finish_mutex_); - bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { - if (is_finish_.load()) { - MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!"; - } - return is_finish_.load(); - }); - return res; -} - bool Node::SendMessageSync(const std::shared_ptr &client, const CommMessage &message, const uint32_t &timeout) { uint64_t request_id = ++next_request_id_; @@ -268,15 +177,6 @@ void Node::SendMessageAsync(const std::shared_ptr &client, const Comm client->SendMessage(message); } -void Node::NotifyMessageArrival(const CommMessage &message) { - std::lock_guard lock(message_tracker_mutex_); - const MessageMeta &message_meta = message.pb_meta(); - uint64_t request_id = message_meta.request_id(); - - message_tracker_[request_id].second++; - message_tracker_cond_.notify_all(); -} - const std::shared_ptr &Node::GetOrCreateTcpClient(const int &rank_id) { std::lock_guard lock(client_mutex_); if (connected_nodes_.find(rank_id) != connected_nodes_.end()) { @@ -292,6 +192,7 @@ const std::shared_ptr &Node::GetOrCreateTcpClient(const int &rank_id) switch (message.pb_meta().cmd()) { case NodeCommand::SEND_DATA: ProcessSendDataResp(message); + RunMessageCallback(message.pb_meta().request_id()); break; default: MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; @@ -317,13 +218,13 @@ void Node::ProcessSendDataResp(const CommMessage &message) { res.insert(std::make_pair(rank_id, message)); receive_messages_[request_id] = res; } - - RunMessageCallback(request_id); } void Node::RunMessageCallback(const uint64_t &request_id) { message_callbacks_mutex_.lock(); - if (message_tracker_[request_id].first == message_tracker_[request_id].second - 1) { + // When receiving a message's response, Then compare with the desired number of responses, + // If they are equal, then call the callback function + if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) { auto it = message_callbacks_.find(request_id); if (it != message_callbacks_.end()) { message_callbacks_mutex_.unlock(); @@ -346,6 +247,15 @@ void Node::set_message_callback(const uint64_t &request_id, const MessageCallbac std::lock_guard lock(message_callbacks_mutex_); message_callbacks_[request_id] = message_callback; } + +void Node::NotifyMessageArrival(const CommMessage &message) { + std::lock_guard lock(message_tracker_mutex_); + const MessageMeta &message_meta = message.pb_meta(); + uint64_t request_id = message_meta.request_id(); + + message_tracker_[request_id].second++; + message_tracker_cond_.notify_all(); +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/node.h b/mindspore/ccsrc/ps/core/node.h index cecebbb229..2f8190f2d3 100644 --- a/mindspore/ccsrc/ps/core/node.h +++ b/mindspore/ccsrc/ps/core/node.h @@ -52,8 +52,7 @@ class Node { is_timeout_(false), is_already_stopped_(true), is_already_finished_(false), - next_request_id_(0), - heart_beat_thread_(nullptr) {} + next_request_id_(0) {} virtual ~Node() = default; using OnNodeEventMessage = std::function; @@ -63,9 +62,10 @@ class Node { virtual bool Stop() = 0; virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; - void set_callback(const OnNodeEventMessage &on_node_event_message); std::string node_id() const; uint32_t rank_id() const; + NodeRole role() const; + bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, @@ -73,27 +73,21 @@ class Node { virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, const uint32_t &timeout = kCommTimeoutInSeconds); virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - CommMessage *const comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); + CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, std::vector *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); protected: - void Heartbeat(const std::shared_ptr &client); - void ProcessHeartbeatResp(const CommMessage &message); - void FetchServers(const std::shared_ptr &client); - void ProcessFetchServersResp(const CommMessage &message); - bool Disconnect(const std::shared_ptr &client, const uint32_t &timeout); bool WaitForStart(const uint32_t &timeout); - bool WaitForDisconnect(const uint32_t &timeout); bool SendMessageSync(const std::shared_ptr &client, const CommMessage &message, const uint32_t &timeout = kCommTimeoutInSeconds); void SendMessageAsync(const std::shared_ptr &client, const CommMessage &message); - void NotifyMessageArrival(const CommMessage &message); const std::shared_ptr &GetOrCreateTcpClient(const int &rank_id); void ProcessSendDataResp(const CommMessage &message); void RunMessageCallback(const uint64_t &request_id); void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); + void NotifyMessageArrival(const CommMessage &message); NodeInfo node_info_; std::atomic is_ready_; @@ -102,9 +96,6 @@ class Node { std::atomic is_already_stopped_; std::atomic is_already_finished_; std::atomic_uint64_t next_request_id_; - std::unique_ptr heart_beat_thread_; - - OnNodeEventMessage on_node_event_message_; // -> std::map, std::pair> nodes_address_; @@ -132,5 +123,4 @@ class Node { } // namespace core } // namespace ps } // namespace mindspore - #endif // MINDSPORE_CCSRC_PS_CORE_NODE_H_ diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc new file mode 100644 index 0000000000..987451666e --- /dev/null +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ps/core/server_node.h" + +namespace mindspore { +namespace ps { +namespace core { +ServerNode::~ServerNode() { + MS_LOG(INFO) << "Stop server node!"; + if (!is_already_stopped_.load()) { + server_->Stop(); + client_to_scheduler_->Stop(); + client_to_scheduler_->StopEventBase(); + if (server_thread_->joinable()) { + server_thread_->join(); + } + if (client_to_scheduler_thread_->joinable()) { + client_to_scheduler_thread_->join(); + } + is_already_stopped_ = true; + } +} + +bool ServerNode::Start(const uint32_t &timeout) { + MS_LOG(INFO) << "Start server node!"; + Initialize(); + Register(client_to_scheduler_); + Heartbeat(client_to_scheduler_); + + if (!WaitForStart(timeout)) { + MS_LOG(EXCEPTION) << "Start Worker node timeout!"; + } + MS_LOG(INFO) << "The cluster is ready to use!"; + + // If the cluster is ready to use, then Get the address of all the servers + if (!is_timeout_.load()) { + FetchServers(client_to_scheduler_); + MS_LOG(INFO) << "Server node get all the servers address successful!"; + } + MS_LOG(INFO) << "Start the node is successful!"; + return true; +} + +void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } + +void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, + const std::string &message) { + auto &meta = const_cast(message_meta); + meta.set_role(node_info_.node_role_); + meta.set_rank_id(node_info_.rank_id_); + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {meta}; + comm_message.set_data(message); + + const_cast(server).SendMessage(conn, comm_message); +} + +void ServerNode::CreateTcpServer() { + std::string interface; + std::string server_ip; + CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); + server_ = std::make_shared(server_ip, 0); + server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { + switch (message.pb_meta().cmd()) { + case NodeCommand::SEND_DATA: + ProcessSendData(server, conn, message); + break; + default: + MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; + } + }); + server_->Init(); + server_thread_ = std::make_unique([&]() { + MS_LOG(INFO) << "The server node start a tcp server!"; + server_->Start(); + }); + server_thread_->detach(); +} + +void ServerNode::Initialize() { + CreateTcpServer(); + is_already_stopped_ = false; + node_info_.node_id_ = CommUtil::GenerateUUID(); + node_info_.node_role_ = NodeRole::SERVER; + node_info_.ip_ = server_->BoundIp(); + node_info_.port_ = server_->BoundPort(); + MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << " is generate uuid is:" << node_info_.node_id_; + if (!InitClientToScheduler()) { + MS_LOG(EXCEPTION) << "Server node init client timeout!"; + } + MS_LOG(INFO) << "Server node init client successful!"; +} + +void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { + if (request_handler_) { + request_handler_(server, conn, message.pb_meta(), message.data()); + } +} + +bool ServerNode::Stop() { + MS_LOG(INFO) << "Stop server node!"; + if (!is_already_stopped_.load()) { + server_->Stop(); + client_to_scheduler_->Stop(); + client_to_scheduler_->StopEventBase(); + if (server_thread_->joinable()) { + server_thread_->join(); + } + if (client_to_scheduler_thread_->joinable()) { + client_to_scheduler_thread_->join(); + } + if (heart_beat_thread_->joinable()) { + heart_beat_thread_->join(); + } + is_already_stopped_ = true; + } + return true; +} + +bool ServerNode::Finish(const uint32_t &timeout) { + std::lock_guard lock(finish_mutex_); + if (is_already_finished_) { + MS_LOG(INFO) << "Server node already finish!"; + return true; + } + is_already_finished_ = true; + return Disconnect(client_to_scheduler_, timeout); +} +} // namespace core +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h new file mode 100644 index 0000000000..2c3d728dfa --- /dev/null +++ b/mindspore/ccsrc/ps/core/server_node.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_ +#define MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_ + +#include +#include +#include +#include +#include +#include + +#include "proto/comm.pb.h" +#include "proto/ps.pb.h" +#include "ps/core/cluster_config.h" +#include "ps/core/tcp_client.h" +#include "ps/core/tcp_server.h" +#include "ps/core/abstract_node.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +namespace core { +class ServerNode : public AbstractNode { + public: + ServerNode() : server_(nullptr), server_thread_(nullptr) {} + ~ServerNode() override; + + bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; + bool Stop() override; + bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; + + using RequestHandler = std::function; + + void set_handler(const RequestHandler &handler); + void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, + const std::string &message); + + private: + void CreateTcpServer(); + void Initialize(); + void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); + + std::shared_ptr server_; + std::unique_ptr server_thread_; + RequestHandler request_handler_; +}; +} // namespace core +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_ diff --git a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc index 8868daa60e..a59431c257 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -35,7 +35,6 @@ namespace mindspore { namespace ps { namespace core { - event_base *TcpClient::event_base_ = nullptr; TcpClient::TcpClient(const std::string &address, std::uint16_t port) @@ -43,7 +42,8 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port) buffer_event_(nullptr), server_address_(std::move(address)), server_port_(port), - is_stop_(true) { + is_stop_(true), + is_connected_(false) { message_handler_.SetCallback([this](const CommMessage &message) { if (message_callback_) { message_callback_(*this, message); @@ -55,12 +55,15 @@ TcpClient::~TcpClient() { Stop(); } std::string TcpClient::GetServerAddress() const { return server_address_; } -void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, - const OnTimeout &timeout) { - connected_callback_ = conn; - disconnected_callback_ = disconn; - read_callback_ = read; - timeout_callback_ = timeout; +void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) { disconnected_callback_ = disconnected; } + +void TcpClient::set_connected_callback(const OnConnected &connected) { connected_callback_ = connected; } + +bool TcpClient::WaitConnected(const uint32_t &connected_timeout) { + std::unique_lock lock(connection_mutex_); + bool res = + connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), [&] { return is_connected_.load(); }); + return res; } void TcpClient::Init() { @@ -68,6 +71,7 @@ void TcpClient::Init() { if (buffer_event_) { return; } + is_stop_ = false; if (!CommUtil::CheckIp(server_address_)) { MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; } @@ -198,6 +202,12 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) { } } +void TcpClient::NotifyConnected() { + MS_LOG(INFO) << "Client connected to the server!"; + is_connected_ = true; + connection_cond_.notify_all(); +} + void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { MS_EXCEPTION_IF_NULL(bev); MS_EXCEPTION_IF_NULL(ptr); @@ -205,27 +215,24 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void if (events & BEV_EVENT_CONNECTED) { // Connected if (tcp_client->connected_callback_) { - tcp_client->connected_callback_(*tcp_client); + tcp_client->connected_callback_(); } - evutil_socket_t fd = bufferevent_getfd(const_cast(bev)); + tcp_client->NotifyConnected(); + evutil_socket_t fd = bufferevent_getfd(bev); SetTcpNoDelay(fd); MS_LOG(INFO) << "Client connected!"; } else if (events & BEV_EVENT_ERROR) { MS_LOG(ERROR) << "Client connected error!"; if (tcp_client->disconnected_callback_) { - tcp_client->disconnected_callback_(*tcp_client, errno); + tcp_client->disconnected_callback_(); } } else if (events & BEV_EVENT_EOF) { MS_LOG(ERROR) << "Client connected end of file"; - if (tcp_client->disconnected_callback_) { - tcp_client->disconnected_callback_(*tcp_client, 0); - } } } void TcpClient::Start() { MS_EXCEPTION_IF_NULL(event_base_); - is_stop_ = false; int ret = event_base_dispatch(event_base_); MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) diff --git a/mindspore/ccsrc/ps/core/tcp_client.h b/mindspore/ccsrc/ps/core/tcp_client.h index d98738b532..d18a6b48b6 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.h +++ b/mindspore/ccsrc/ps/core/tcp_client.h @@ -30,19 +30,19 @@ #include #include #include +#include +#include "ps/core/cluster_config.h" #include "proto/comm.pb.h" #include "proto/ps.pb.h" -#include "ps/core/cluster_config.h" namespace mindspore { namespace ps { namespace core { - class TcpClient { public: - using OnConnected = std::function; - using OnDisconnected = std::function; + using OnConnected = std::function; + using OnDisconnected = std::function; using OnRead = std::function; using OnTimeout = std::function; using OnMessage = std::function; @@ -52,8 +52,9 @@ class TcpClient { virtual ~TcpClient(); std::string GetServerAddress() const; - void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, - const OnTimeout &timeout); + void set_disconnected_callback(const OnDisconnected &disconnected); + void set_connected_callback(const OnConnected &connected); + bool WaitConnected(const uint32_t &connected_timeout = ClusterConfig::cluster_available_timeout()); void Init(); void StartWithDelay(int seconds); void Stop(); @@ -73,6 +74,7 @@ class TcpClient { static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); virtual void OnReadHandler(const void *buf, size_t num); static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); + void NotifyConnected(); private: OnMessage message_callback_; @@ -86,12 +88,14 @@ class TcpClient { static event_base *event_base_; std::mutex connection_mutex_; + std::condition_variable connection_cond_; event *event_timeout_; bufferevent *buffer_event_; std::string server_address_; std::uint16_t server_port_; std::atomic is_stop_; + std::atomic is_connected_; }; } // namespace core diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc index cefc344f2e..af33b2e7d5 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -95,6 +95,7 @@ void TcpServer::Init() { MS_LOG(EXCEPTION) << "Use event pthread failed!"; } + is_stop_ = false; base_ = event_base_new(); MS_EXCEPTION_IF_NULL(base_); if (!CommUtil::CheckIp(server_address_)) { @@ -138,7 +139,6 @@ void TcpServer::Start() { std::unique_lock lock(connection_mutex_); MS_LOG(INFO) << "Start tcp server!"; MS_EXCEPTION_IF_NULL(base_); - is_stop_ = false; int ret = event_base_dispatch(base_); MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) @@ -368,6 +368,7 @@ int TcpServer::ConnectionNum() const { return connections_.size(); } const std::map &TcpServer::Connections() const { return connections_; } void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } + } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index eb38475748..9e14bb7e65 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -31,8 +31,8 @@ WorkerNode::~WorkerNode() { } } client_to_scheduler_->StopEventBase(); - if (worker_thread_->joinable()) { - worker_thread_->join(); + if (client_to_scheduler_thread_->joinable()) { + client_to_scheduler_thread_->join(); } if (heart_beat_thread_->joinable()) { heart_beat_thread_->join(); @@ -43,7 +43,7 @@ WorkerNode::~WorkerNode() { bool WorkerNode::Start(const uint32_t &timeout) { MS_LOG(INFO) << "Starting worker node!"; Initialize(); - Register(); + Register(client_to_scheduler_); Heartbeat(client_to_scheduler_); if (!WaitForStart(timeout)) { @@ -52,84 +52,25 @@ bool WorkerNode::Start(const uint32_t &timeout) { } MS_LOG(INFO) << "The node is ready to fetch servers!"; + // If the cluster is ready to use, then Get the address of all the servers if (!is_timeout_.load()) { FetchServers(client_to_scheduler_); - MS_LOG(INFO) << "Fetch servers successful!"; + MS_LOG(INFO) << "Worker node get all the servers address successful!"; } MS_LOG(INFO) << "The Worker node has successfully started."; return true; } -void WorkerNode::Register() { - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::REGISTER); - - RegisterMessage register_message; - register_message.set_node_id(node_info_.node_id_); - register_message.set_role(node_info_.node_role_); - - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(register_message.SerializeAsString()); - if (!SendMessageSync(client_to_scheduler_, comm_message)) { - MS_LOG(EXCEPTION) << "Worker node register timeout!"; - } - MS_LOG(INFO) << "The worker node id:" << node_info_.node_id_ - << "is registering to scheduler, the request id is:" << message_meta.request_id(); -} - -void WorkerNode::ProcessRegisterResp(const CommMessage &message) { - RegisterRespMessage register_resp_message; - register_resp_message.ParseFromString(message.data()); - if (register_resp_message.node_id() != node_info_.node_id_) { - MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() - << " is not match the current node id:" << node_info_.node_id_; - } - - node_info_.rank_id_ = register_resp_message.rank_id(); - - MS_LOG(INFO) << "The client node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; -} - void WorkerNode::Initialize() { is_already_stopped_ = false; node_info_.node_id_ = CommUtil::GenerateUUID(); node_info_.node_role_ = NodeRole::WORKER; MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_; - InitClientToScheduler(); -} - -void WorkerNode::InitClientToScheduler() { - std::string scheduler_host = ClusterConfig::scheduler_host(); - uint16_t scheduler_port = ClusterConfig::scheduler_port(); - client_to_scheduler_ = std::make_shared(scheduler_host, scheduler_port); - client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { - switch (message.pb_meta().cmd()) { - case NodeCommand::HEARTBEAT: - ProcessHeartbeatResp(message); - break; - case NodeCommand::REGISTER: - ProcessRegisterResp(message); - break; - case NodeCommand::FETCH_SERVER: - ProcessFetchServersResp(message); - break; - case NodeCommand::FINISH: - MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!"; - break; - default: - MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; - } - NotifyMessageArrival(message); - }); - - client_to_scheduler_->Init(); - worker_thread_ = std::make_unique([&]() { - MS_LOG(INFO) << "The worker node start a tcp client!"; - client_to_scheduler_->Start(); - }); - worker_thread_->detach(); + if (!InitClientToScheduler()) { + MS_LOG(EXCEPTION) << "Worker node init client timeout!"; + } + MS_LOG(INFO) << "Worker node init client successful!"; } bool WorkerNode::Stop() { @@ -144,8 +85,8 @@ bool WorkerNode::Stop() { } } client_to_scheduler_->StopEventBase(); - if (worker_thread_->joinable()) { - worker_thread_->join(); + if (client_to_scheduler_thread_->joinable()) { + client_to_scheduler_thread_->join(); } if (heart_beat_thread_->joinable()) { heart_beat_thread_->join(); @@ -165,23 +106,6 @@ bool WorkerNode::Finish(const uint32_t &timeout) { is_already_finished_ = true; return Disconnect(client_to_scheduler_, timeout); } - -bool WorkerNode::BroadcastToServers(const std::string &message) { - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); - for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - message_meta.set_request_id(request_id); - - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(message); - auto client = GetOrCreateTcpClient((*it).first.second); - client->SendMessage(comm_message); - } - return Wait(request_id); -} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/worker_node.h b/mindspore/ccsrc/ps/core/worker_node.h index 32f6622fa5..fe6ac67539 100644 --- a/mindspore/ccsrc/ps/core/worker_node.h +++ b/mindspore/ccsrc/ps/core/worker_node.h @@ -17,50 +17,35 @@ #ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ #define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ -#include #include #include #include #include -#include -#include -#include #include -#include #include -#include #include "proto/comm.pb.h" #include "proto/ps.pb.h" #include "ps/core/cluster_config.h" #include "ps/core/tcp_client.h" #include "ps/core/tcp_server.h" -#include "ps/core/node.h" +#include "ps/core/abstract_node.h" #include "utils/log_adapter.h" namespace mindspore { namespace ps { namespace core { -class WorkerNode : public Node { +class WorkerNode : public AbstractNode { public: - WorkerNode() : client_to_scheduler_(nullptr), worker_thread_(nullptr) {} + WorkerNode() = default; ~WorkerNode() override; bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; - bool BroadcastToServers(const std::string &message); - private: - void Register(); - void ProcessRegisterResp(const CommMessage &message); - void Initialize(); - void InitClientToScheduler(); - - std::shared_ptr client_to_scheduler_; - std::unique_ptr worker_thread_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/embedding_table_shard_metadata.cc b/mindspore/ccsrc/ps/embedding_table_shard_metadata.cc new file mode 100644 index 0000000000..ea4d977f4e --- /dev/null +++ b/mindspore/ccsrc/ps/embedding_table_shard_metadata.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/embedding_table_shard_metadata.h" + +namespace mindspore { +namespace ps { +uint64_t EmbeddingTableShardMetadata::begin() const { return begin_; } + +uint64_t EmbeddingTableShardMetadata::end() const { return end_; } + +uint64_t EmbeddingTableShardMetadata::size() const { return end_ - begin_; } +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/embedding_table_shard_metadata.h b/mindspore/ccsrc/ps/embedding_table_shard_metadata.h new file mode 100644 index 0000000000..11769ac390 --- /dev/null +++ b/mindspore/ccsrc/ps/embedding_table_shard_metadata.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_ +#define MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_ + +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +class EmbeddingTableShardMetadata { + public: + explicit EmbeddingTableShardMetadata(uint64_t begin, uint64_t end) : begin_(begin), end_(end) {} + virtual ~EmbeddingTableShardMetadata() = default; + + uint64_t begin() const; + uint64_t end() const; + uint64_t size() const; + + private: + uint64_t begin_; + uint64_t end_; +}; +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_ diff --git a/tests/ut/cpp/ps/embedding_table_shard_metadata_test.cc b/tests/ut/cpp/ps/embedding_table_shard_metadata_test.cc new file mode 100644 index 0000000000..a0edff8506 --- /dev/null +++ b/tests/ut/cpp/ps/embedding_table_shard_metadata_test.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common_test.h" +#include "ps/embedding_table_shard_metadata.h" + +namespace mindspore { +namespace ps { +class TestEmbeddingTableShardMetadata : public UT::Common { + public: + TestEmbeddingTableShardMetadata() = default; + virtual ~TestEmbeddingTableShardMetadata() = default; + + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(TestEmbeddingTableShardMetadata, EmbeddingTable) { + EmbeddingTableShardMetadata embedding_table_shard(1, 100); + EXPECT_EQ(embedding_table_shard.begin(), 1); + EXPECT_EQ(embedding_table_shard.end(), 100); + EXPECT_EQ(embedding_table_shard.size(), 99); +} +} // namespace ps +} // namespace mindspore