@@ -5,6 +5,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
   list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "util.cc")
+  list(REMOVE_ITEM _PS_SRC_FILES "embedding_table_shard_metadata.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "core/http_message_handler.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "core/http_server.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "core/comm_util.cc")
@@ -16,6 +17,8 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
   list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc")
+  list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc")
+  list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc")
 endif ()
 if (NOT ENABLE_D)
@@ -0,0 +1,212 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ps/core/abstract_node.h"
+namespace mindspore {
+namespace ps {
+namespace core {
+void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
+  MessageMeta message_meta;
+  message_meta.set_cmd(NodeCommand::REGISTER);
+  RegisterMessage register_message;
+  register_message.set_node_id(node_info_.node_id_);
+  register_message.set_role(node_info_.node_role_);
+  register_message.set_ip(node_info_.ip_);
+  register_message.set_port(node_info_.port_);
+  CommMessage comm_message;
+  *comm_message.mutable_pb_meta() = {message_meta};
+  comm_message.set_data(register_message.SerializeAsString());
+  if (!SendMessageSync(client, comm_message)) {
+    MS_LOG(EXCEPTION) << "Node register timeout!";
+  }
+  MS_LOG(INFO) << "The node id:" << node_info_.node_id_
+               << "is registering to scheduler, the request id is:" << message_meta.request_id();
+}
+void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
+  RegisterRespMessage register_resp_message;
+  register_resp_message.ParseFromString(message.data());
+  if (register_resp_message.node_id() != node_info_.node_id_) {
+    MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
+                      << " is not match the current node id:" << node_info_.node_id_;
+  }
+  node_info_.rank_id_ = register_resp_message.rank_id();
+  MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
+}
+bool AbstractNode::BroadcastToServers(const std::string &message, const uint32_t &timeout) {
+  uint64_t request_id = ++next_request_id_;
+  message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
+  for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
+    MessageMeta message_meta;
+    message_meta.set_cmd(NodeCommand::SEND_DATA);
+    message_meta.set_request_id(request_id);
+    CommMessage comm_message;
+    *comm_message.mutable_pb_meta() = {message_meta};
+    comm_message.set_data(message);
+    auto client = GetOrCreateTcpClient((*it).first.second);
+    client->SendMessage(comm_message);
+  }
+  return Wait(request_id, timeout);
+}
+void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_message) {
+  on_node_event_message_ = on_node_event_message;
+}
+void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
+  MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
+               << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
+               << " begin send heartbeat to the scheduler!";
+  heart_beat_thread_ = std::make_unique<std::thread>([&]() {
+    while (!is_finish_.load()) {
+      std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
+      MessageMeta meta;
+      meta.set_cmd(NodeCommand::HEARTBEAT);
+      HeartbeatMessage heartbeat_message;
+      heartbeat_message.set_node_id(node_info_.node_id_);
+      CommMessage message;
+      *message.mutable_pb_meta() = {meta};
+      message.set_data(heartbeat_message.SerializeAsString());
+      if (!SendMessageSync(client, message)) {
+        MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
+      }
+    }
+  });
+  heart_beat_thread_->detach();
+}
+void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
+  HeartbeatRespMessage heartbeat_resp_message;
+  heartbeat_resp_message.ParseFromString(message.data());
+  is_ready_ = heartbeat_resp_message.is_cluster_ready();
+  if (is_ready_.load()) {
+    wait_start_cond_.notify_all();
+    MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
+  }
+  is_finish_ = heartbeat_resp_message.is_cluster_finish();
+  if (is_finish_.load()) {
+    wait_finish_cond_.notify_all();
+    MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
+  }
+  is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
+  if (is_timeout_ && on_node_event_message_) {
+    is_ready_ = true;
+    wait_start_cond_.notify_all();
+    on_node_event_message_(NodeEvent::NODE_TIMEOUT);
+  }
+}
+void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
+  MessageMeta meta;
+  meta.set_cmd(NodeCommand::FETCH_SERVER);
+  CommMessage message;
+  *message.mutable_pb_meta() = {meta};
+  if (!SendMessageSync(client, message)) {
+    MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
+  }
+}
+void AbstractNode::ProcessFetchServersResp(const CommMessage &message) {
+  FetchServersRespMessage fetch_servers_resp_message;
+  fetch_servers_resp_message.ParseFromString(message.data());
+  for (const auto &it : fetch_servers_resp_message.servers_meta()) {
+    nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
+  }
+  MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
+}
+bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
+  MessageMeta meta;
+  meta.set_cmd(NodeCommand::FINISH);
+  FinishMessage finish_message;
+  finish_message.set_node_id(node_info_.node_id_);
+  CommMessage message;
+  *message.mutable_pb_meta() = {meta};
+  message.set_data(finish_message.SerializeAsString());
+  if (!SendMessageSync(client, message)) {
+    MS_LOG(EXCEPTION) << "Disconnect timeout!";
+  }
+  MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
+  return WaitForDisconnect(timeout);
+}
+bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
+  std::unique_lock<std::mutex> lock(wait_finish_mutex_);
+  bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
+    if (is_finish_.load()) {
+      MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
+    }
+    return is_finish_.load();
+  });
+  return res;
+}
+bool AbstractNode::InitClientToScheduler() {
+  std::string scheduler_host = ClusterConfig::scheduler_host();
+  uint16_t scheduler_port = ClusterConfig::scheduler_port();
+  client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
+  client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
+    switch (message.pb_meta().cmd()) {
+      case NodeCommand::HEARTBEAT:
+        ProcessHeartbeatResp(message);
+        break;
+      case NodeCommand::REGISTER:
+        ProcessRegisterResp(message);
+        break;
+      case NodeCommand::FETCH_SERVER:
+        ProcessFetchServersResp(message);
+        break;
+      case NodeCommand::FINISH:
+        MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
+        break;
+      default:
+        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+    }
+    NotifyMessageArrival(message);
+  });
+  client_to_scheduler_->Init();
+  client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
+    MS_LOG(INFO) << "The worker node start a tcp client!";
+    client_to_scheduler_->Start();
+  });
+  client_to_scheduler_thread_->detach();
+  client_to_scheduler_->set_disconnected_callback([&]() {
+    std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval()));
+    client_to_scheduler_->Stop();
+    client_to_scheduler_->Init();
+  });
+  return client_to_scheduler_->WaitConnected();
+}
+}  // namespace core
+}  // namespace ps
+}  // namespace mindspore
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
+#define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
+#include <utility>
+#include <string>
+#include <memory>
+#include "ps/core/node.h"
+namespace mindspore {
+namespace ps {
+namespace core {
+class AbstractNode : public Node {
+ public:
+  AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
+  ~AbstractNode() override = default;
+  bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds);
+  void set_event_callback(const OnNodeEventMessage &on_node_event_message);
+ protected:
+  void Register(const std::shared_ptr<TcpClient> &client);
+  void ProcessRegisterResp(const CommMessage &message);
+  void Heartbeat(const std::shared_ptr<TcpClient> &client);
+  void ProcessHeartbeatResp(const CommMessage &message);
+  void FetchServers(const std::shared_ptr<TcpClient> &client);
+  void ProcessFetchServersResp(const CommMessage &message);
+  bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
+  bool WaitForDisconnect(const uint32_t &timeout);
+  bool InitClientToScheduler();
+  std::unique_ptr<std::thread> heart_beat_thread_;
+  std::unique_ptr<std::thread> client_to_scheduler_thread_;
+  std::shared_ptr<TcpClient> client_to_scheduler_;
+  OnNodeEventMessage on_node_event_message_;
+};
+}  // namespace core
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
@@ -31,6 +31,8 @@ uint32_t ClusterConfig::heartbeat_interval_ = 3;
 uint32_t ClusterConfig::heartbeat_timeout_ = 30;
 // Timeout period for cluster preparation is 300 seconds.
 uint32_t ClusterConfig::cluster_available_timeout_ = 300;
+// The timeout period for the client to connect to the server is 100ms.
+uint32_t ClusterConfig::connect_interval_ = 100;
 void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num,
                          std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) {
@@ -69,6 +71,9 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa
   cluster_available_timeout_ = cluster_available_timeout;
 }
+uint32_t ClusterConfig::connect_interval() { return connect_interval_; }
+void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; }
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore
@@ -42,6 +42,8 @@ class ClusterConfig {
   static void set_heartbeat_timeout(const uint32_t &heartbeat_timeout);
   static uint32_t cluster_available_timeout();
   static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout);
+  static uint32_t connect_interval();
+  static void set_connect_interval(const uint32_t &connect_interval);
  private:
   static uint32_t worker_num_;
@@ -51,6 +53,7 @@ class ClusterConfig {
   static uint16_t scheduler_port_;
   static uint32_t heartbeat_timeout_;
   static uint32_t cluster_available_timeout_;
+  static uint32_t connect_interval_;
 };
 }  // namespace core
 }  // namespace ps
@@ -19,78 +19,11 @@
 namespace mindspore {
 namespace ps {
 namespace core {
-void Node::Heartbeat(const std::shared_ptr<TcpClient> &client) {
-  MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
-               << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
-               << " begin send heartbeat to the scheduler!";
-  heart_beat_thread_ = std::make_unique<std::thread>([&]() {
-    while (!is_finish_.load()) {
-      std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
-      MessageMeta meta;
-      meta.set_cmd(NodeCommand::HEARTBEAT);
-      HeartbeatMessage heartbeat_message;
-      heartbeat_message.set_node_id(node_info_.node_id_);
-      CommMessage message;
-      *message.mutable_pb_meta() = {meta};
-      message.set_data(heartbeat_message.SerializeAsString());
-      SendMessageAsync(client, message);
-    }
-  });
-  heart_beat_thread_->detach();
-}
-void Node::ProcessHeartbeatResp(const CommMessage &message) {
-  HeartbeatRespMessage heartbeat_resp_message;
-  heartbeat_resp_message.ParseFromString(message.data());
-  is_ready_ = heartbeat_resp_message.is_cluster_ready();
-  if (is_ready_.load()) {
-    wait_start_cond_.notify_all();
-    MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
-  }
-  is_finish_ = heartbeat_resp_message.is_cluster_finish();
-  if (is_finish_.load()) {
-    wait_finish_cond_.notify_all();
-    MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
-  }
-  is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
-  if (is_timeout_ && on_node_event_message_) {
-    is_ready_ = true;
-    wait_start_cond_.notify_all();
-    on_node_event_message_(NodeEvent::NODE_TIMEOUT);
-  }
-}
-void Node::FetchServers(const std::shared_ptr<TcpClient> &client) {
-  MessageMeta meta;
-  meta.set_cmd(NodeCommand::FETCH_SERVER);
-  CommMessage message;
-  *message.mutable_pb_meta() = {meta};
-  if (!SendMessageSync(client, message)) {
-    MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
-  }
-}
-void Node::ProcessFetchServersResp(const CommMessage &message) {
-  FetchServersRespMessage fetch_servers_resp_message;
-  fetch_servers_resp_message.ParseFromString(message.data());
-  for (const auto &it : fetch_servers_resp_message.servers_meta()) {
-    nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
-  }
-  MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
-}
 std::string Node::node_id() const { return node_info_.node_id_; }
 uint32_t Node::rank_id() const { return node_info_.rank_id_; }
-void Node::set_callback(const OnNodeEventMessage &on_node_event_message) {
-  on_node_event_message_ = on_node_event_message;
-}
+NodeRole Node::role() const { return node_info_.node_role_; }
 bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
   std::unique_lock<std::mutex> lock(message_tracker_mutex_);
@@ -147,6 +80,7 @@ bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids
 bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
                 CommMessage *comm_message_resp, const uint32_t &timeout) {
+  MS_EXCEPTION_IF_NULL(comm_message_resp);
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
@@ -156,7 +90,7 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s
   set_message_callback(request_id, [&]() {
     receive_messages_mutex_.lock();
     auto res = receive_messages_[request_id];
-    comm_message_resp = &res[rank_id];
+    *comm_message_resp = res[rank_id];
     receive_messages_.erase(request_id);
     receive_messages_mutex_.unlock();
   });
@@ -164,6 +98,8 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s
   MessageMeta message_meta;
   message_meta.set_cmd(NodeCommand::SEND_DATA);
   message_meta.set_request_id(request_id);
+  message_meta.set_rank_id(node_info_.rank_id_);
+  message_meta.set_role(node_info_.node_role_);
   CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
@@ -175,6 +111,7 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s
 bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
                 std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) {
+  MS_EXCEPTION_IF_NULL(comm_message_resp);
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(data.size(), 0);
@@ -213,23 +150,6 @@ bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids
   return Wait(request_id, timeout);
 }
-bool Node::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
-  MessageMeta meta;
-  meta.set_cmd(NodeCommand::FINISH);
-  FinishMessage finish_message;
-  finish_message.set_node_id(node_info_.node_id_);
-  CommMessage message;
-  *message.mutable_pb_meta() = {meta};
-  message.set_data(finish_message.SerializeAsString());
-  if (!SendMessageSync(client, message)) {
-    MS_LOG(EXCEPTION) << "Disconnect timeout!";
-  }
-  MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
-  return WaitForDisconnect(timeout);
-}
 bool Node::WaitForStart(const uint32_t &timeout) {
   std::unique_lock<std::mutex> lock(wait_start_mutex_);
   bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
@@ -242,17 +162,6 @@ bool Node::WaitForStart(const uint32_t &timeout) {
   return res;
 }
-bool Node::WaitForDisconnect(const uint32_t &timeout) {
-  std::unique_lock<std::mutex> lock(wait_finish_mutex_);
-  bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
-    if (is_finish_.load()) {
-      MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
-    }
-    return is_finish_.load();
-  });
-  return res;
-}
 bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
                            const uint32_t &timeout) {
   uint64_t request_id = ++next_request_id_;
@@ -268,15 +177,6 @@ void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const Comm
   client->SendMessage(message);
 }
-void Node::NotifyMessageArrival(const CommMessage &message) {
-  std::lock_guard<std::mutex> lock(message_tracker_mutex_);
-  const MessageMeta &message_meta = message.pb_meta();
-  uint64_t request_id = message_meta.request_id();
-  message_tracker_[request_id].second++;
-  message_tracker_cond_.notify_all();
-}
 const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) {
   std::lock_guard<std::mutex> lock(client_mutex_);
   if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
@@ -292,6 +192,7 @@ const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id)
       switch (message.pb_meta().cmd()) {
         case NodeCommand::SEND_DATA:
           ProcessSendDataResp(message);
+          RunMessageCallback(message.pb_meta().request_id());
           break;
         default:
           MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
@@ -317,13 +218,13 @@ void Node::ProcessSendDataResp(const CommMessage &message) {
     res.insert(std::make_pair(rank_id, message));
     receive_messages_[request_id] = res;
   }
-  RunMessageCallback(request_id);
 }
 void Node::RunMessageCallback(const uint64_t &request_id) {
   message_callbacks_mutex_.lock();
-  if (message_tracker_[request_id].first == message_tracker_[request_id].second - 1) {
+  // When receiving a message's response, Then compare with the desired number of responses,
+  // If they are equal, then call the callback function
+  if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) {
     auto it = message_callbacks_.find(request_id);
     if (it != message_callbacks_.end()) {
       message_callbacks_mutex_.unlock();
@@ -346,6 +247,15 @@ void Node::set_message_callback(const uint64_t &request_id, const MessageCallbac
   std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
   message_callbacks_[request_id] = message_callback;
 }
+void Node::NotifyMessageArrival(const CommMessage &message) {
+  std::lock_guard<std::mutex> lock(message_tracker_mutex_);
+  const MessageMeta &message_meta = message.pb_meta();
+  uint64_t request_id = message_meta.request_id();
+  message_tracker_[request_id].second++;
+  message_tracker_cond_.notify_all();
+}
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore
@@ -52,8 +52,7 @@
         is_timeout_(false),
         is_already_stopped_(true),
         is_already_finished_(false),
-        next_request_id_(0),
-        heart_beat_thread_(nullptr) {}
+        next_request_id_(0) {}
   virtual ~Node() = default;
   using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
@@ -63,9 +62,10 @@
   virtual bool Stop() = 0;
   virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0;
-  void set_callback(const OnNodeEventMessage &on_node_event_message);
   std::string node_id() const;
   uint32_t rank_id() const;
+  NodeRole role() const;
   bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
   virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
@@ -73,27 +73,21 @@
   virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
                     const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
   virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
-                    CommMessage *const comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
+                    CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
   virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
                     const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
                     const uint32_t &timeout = kCommTimeoutInSeconds);
  protected:
-  void Heartbeat(const std::shared_ptr<TcpClient> &client);
-  void ProcessHeartbeatResp(const CommMessage &message);
-  void FetchServers(const std::shared_ptr<TcpClient> &client);
-  void ProcessFetchServersResp(const CommMessage &message);
-  bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
   bool WaitForStart(const uint32_t &timeout);
-  bool WaitForDisconnect(const uint32_t &timeout);
   bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
                        const uint32_t &timeout = kCommTimeoutInSeconds);
   void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
-  void NotifyMessageArrival(const CommMessage &message);
   const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
   void ProcessSendDataResp(const CommMessage &message);
   void RunMessageCallback(const uint64_t &request_id);
   void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
+  void NotifyMessageArrival(const CommMessage &message);
   NodeInfo node_info_;
   std::atomic<bool> is_ready_;
@@ -102,9 +96,6 @@
   std::atomic<bool> is_already_stopped_;
   std::atomic<bool> is_already_finished_;
   std::atomic_uint64_t next_request_id_;
-  std::unique_ptr<std::thread> heart_beat_thread_;
-  OnNodeEventMessage on_node_event_message_;
   // <NodeRole,rank_id>-><ip, port>
   std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
@@ -132,5 +123,4 @@
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PS_CORE_NODE_H_
@@ -0,0 +1,145 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ps/core/server_node.h"
+namespace mindspore {
+namespace ps {
+namespace core {
+ServerNode::~ServerNode() {
+  MS_LOG(INFO) << "Stop server node!";
+  if (!is_already_stopped_.load()) {
+    server_->Stop();
+    client_to_scheduler_->Stop();
+    client_to_scheduler_->StopEventBase();
+    if (server_thread_->joinable()) {
+      server_thread_->join();
+    }
+    if (client_to_scheduler_thread_->joinable()) {
+      client_to_scheduler_thread_->join();
+    }
+    is_already_stopped_ = true;
+  }
+}
+bool ServerNode::Start(const uint32_t &timeout) {
+  MS_LOG(INFO) << "Start server node!";
+  Initialize();
+  Register(client_to_scheduler_);
+  Heartbeat(client_to_scheduler_);
+  if (!WaitForStart(timeout)) {
+    MS_LOG(EXCEPTION) << "Start Worker node timeout!";
+  }
+  MS_LOG(INFO) << "The cluster is ready to use!";
+  // If the cluster is ready to use, then Get the address of all the servers
+  if (!is_timeout_.load()) {
+    FetchServers(client_to_scheduler_);
+    MS_LOG(INFO) << "Server node get all the servers address successful!";
+  }
+  MS_LOG(INFO) << "Start the node is successful!";
+  return true;
+}
+void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
+void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
+                          const std::string &message) {
+  auto &meta = const_cast<MessageMeta &>(message_meta);
+  meta.set_role(node_info_.node_role_);
+  meta.set_rank_id(node_info_.rank_id_);
+  CommMessage comm_message;
+  *comm_message.mutable_pb_meta() = {meta};
+  comm_message.set_data(message);
+  const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
+}
+void ServerNode::CreateTcpServer() {
+  std::string interface;
+  std::string server_ip;
+  CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
+  server_ = std::make_shared<TcpServer>(server_ip, 0);
+  server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
+    switch (message.pb_meta().cmd()) {
+      case NodeCommand::SEND_DATA:
+        ProcessSendData(server, conn, message);
+        break;
+      default:
+        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+    }
+  });
+  server_->Init();
+  server_thread_ = std::make_unique<std::thread>([&]() {
+    MS_LOG(INFO) << "The server node start a tcp server!";
+    server_->Start();
+  });
+  server_thread_->detach();
+}
+void ServerNode::Initialize() {
+  CreateTcpServer();
+  is_already_stopped_ = false;
+  node_info_.node_id_ = CommUtil::GenerateUUID();
+  node_info_.node_role_ = NodeRole::SERVER;
+  node_info_.ip_ = server_->BoundIp();
+  node_info_.port_ = server_->BoundPort();
+  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+               << " is generate uuid is:" << node_info_.node_id_;
+  if (!InitClientToScheduler()) {
+    MS_LOG(EXCEPTION) << "Server node init client timeout!";
+  }
+  MS_LOG(INFO) << "Server node init client successful!";
+}
+void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
+  if (request_handler_) {
+    request_handler_(server, conn, message.pb_meta(), message.data());
+  }
+}
+bool ServerNode::Stop() {
+  MS_LOG(INFO) << "Stop server node!";
+  if (!is_already_stopped_.load()) {
+    server_->Stop();
+    client_to_scheduler_->Stop();
+    client_to_scheduler_->StopEventBase();
+    if (server_thread_->joinable()) {
+      server_thread_->join();
+    }
+    if (client_to_scheduler_thread_->joinable()) {
+      client_to_scheduler_thread_->join();
+    }
+    if (heart_beat_thread_->joinable()) {
+      heart_beat_thread_->join();
+    }
+    is_already_stopped_ = true;
+  }
+  return true;
+}
+bool ServerNode::Finish(const uint32_t &timeout) {
+  std::lock_guard<std::mutex> lock(finish_mutex_);
+  if (is_already_finished_) {
+    MS_LOG(INFO) << "Server node already finish!";
+    return true;
+  }
+  is_already_finished_ = true;
+  return Disconnect(client_to_scheduler_, timeout);
+}
+}  // namespace core
+}  // namespace ps
+}  // namespace mindspore
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_
+#define MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <thread>
+#include <utility>
+#include "proto/comm.pb.h"
+#include "proto/ps.pb.h"
+#include "ps/core/cluster_config.h"
+#include "ps/core/tcp_client.h"
+#include "ps/core/tcp_server.h"
+#include "ps/core/abstract_node.h"
+#include "utils/log_adapter.h"
+namespace mindspore {
+namespace ps {
+namespace core {
+class ServerNode : public AbstractNode {
+ public:
+  ServerNode() : server_(nullptr), server_thread_(nullptr) {}
+  ~ServerNode() override;
+  bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
+  bool Stop() override;
+  bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
+  using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn,
+                                            const MessageMeta message_meta, const std::string &message)>;
+  void set_handler(const RequestHandler &handler);
+  void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
+                const std::string &message);
+ private:
+  void CreateTcpServer();
+  void Initialize();
+  void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
+  std::shared_ptr<TcpServer> server_;
+  std::unique_ptr<std::thread> server_thread_;
+  RequestHandler request_handler_;
+};
+}  // namespace core
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_
@@ -35,7 +35,6 @@
 namespace mindspore {
 namespace ps {
 namespace core {
 event_base *TcpClient::event_base_ = nullptr;
 TcpClient::TcpClient(const std::string &address, std::uint16_t port)
@@ -43,7 +42,8 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
       buffer_event_(nullptr),
       server_address_(std::move(address)),
       server_port_(port),
-      is_stop_(true) {
+      is_stop_(true),
+      is_connected_(false) {
   message_handler_.SetCallback([this](const CommMessage &message) {
     if (message_callback_) {
       message_callback_(*this, message);
@@ -55,12 +55,15 @@ TcpClient::~TcpClient() { Stop(); }
 std::string TcpClient::GetServerAddress() const { return server_address_; }
-void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
-                            const OnTimeout &timeout) {
-  connected_callback_ = conn;
-  disconnected_callback_ = disconn;
-  read_callback_ = read;
-  timeout_callback_ = timeout;
+void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) { disconnected_callback_ = disconnected; }
+void TcpClient::set_connected_callback(const OnConnected &connected) { connected_callback_ = connected; }
+bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
+  std::unique_lock<std::mutex> lock(connection_mutex_);
+  bool res =
+    connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), [&] { return is_connected_.load(); });
+  return res;
 }
 void TcpClient::Init() {
@@ -68,6 +71,7 @@ void TcpClient::Init() {
   if (buffer_event_) {
     return;
   }
+  is_stop_ = false;
   if (!CommUtil::CheckIp(server_address_)) {
     MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
   }
@@ -198,6 +202,12 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
   }
 }
+void TcpClient::NotifyConnected() {
+  MS_LOG(INFO) << "Client connected to the server!";
+  is_connected_ = true;
+  connection_cond_.notify_all();
+}
 void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
   MS_EXCEPTION_IF_NULL(bev);
   MS_EXCEPTION_IF_NULL(ptr);
@@ -205,27 +215,24 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void
   if (events & BEV_EVENT_CONNECTED) {
     // Connected
     if (tcp_client->connected_callback_) {
-      tcp_client->connected_callback_(*tcp_client);
+      tcp_client->connected_callback_();
     }
-    evutil_socket_t fd = bufferevent_getfd(const_cast<struct bufferevent *>(bev));
+    tcp_client->NotifyConnected();
+    evutil_socket_t fd = bufferevent_getfd(bev);
     SetTcpNoDelay(fd);
     MS_LOG(INFO) << "Client connected!";
   } else if (events & BEV_EVENT_ERROR) {
     MS_LOG(ERROR) << "Client connected error!";
     if (tcp_client->disconnected_callback_) {
-      tcp_client->disconnected_callback_(*tcp_client, errno);
+      tcp_client->disconnected_callback_();
    }
  } else if (events & BEV_EVENT_EOF) {
     MS_LOG(ERROR) << "Client connected end of file";
-    if (tcp_client->disconnected_callback_) {
-      tcp_client->disconnected_callback_(*tcp_client, 0);
-    }
   }
 }
 void TcpClient::Start() {
   MS_EXCEPTION_IF_NULL(event_base_);
-  is_stop_ = false;
   int ret = event_base_dispatch(event_base_);
   MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
   MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
@@ -30,19 +30,19 @@
 #include <thread>
 #include <mutex>
 #include <atomic>
+#include <condition_variable>
+#include "ps/core/cluster_config.h"
 #include "proto/comm.pb.h"
 #include "proto/ps.pb.h"
-#include "ps/core/cluster_config.h"
 namespace mindspore {
 namespace ps {
 namespace core {
 class TcpClient {
  public:
-  using OnConnected = std::function<void(const TcpClient &)>;
-  using OnDisconnected = std::function<void(const TcpClient &, int)>;
+  using OnConnected = std::function<void()>;
+  using OnDisconnected = std::function<void()>;
   using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
   using OnTimeout = std::function<void(const TcpClient &)>;
   using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
@@ -52,8 +52,9 @@ class TcpClient {
   virtual ~TcpClient();
   std::string GetServerAddress() const;
-  void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
-                   const OnTimeout &timeout);
+  void set_disconnected_callback(const OnDisconnected &disconnected);
+  void set_connected_callback(const OnConnected &connected);
+  bool WaitConnected(const uint32_t &connected_timeout = ClusterConfig::cluster_available_timeout());
   void Init();
   void StartWithDelay(int seconds);
   void Stop();
@@ -73,6 +74,7 @@ class TcpClient {
   static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
   virtual void OnReadHandler(const void *buf, size_t num);
   static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
+  void NotifyConnected();
  private:
   OnMessage message_callback_;
@@ -86,12 +88,14 @@ class TcpClient {
   static event_base *event_base_;
   std::mutex connection_mutex_;
+  std::condition_variable connection_cond_;
   event *event_timeout_;
   bufferevent *buffer_event_;
   std::string server_address_;
   std::uint16_t server_port_;
   std::atomic<bool> is_stop_;
+  std::atomic<bool> is_connected_;
 };
 }  // namespace core
@@ -95,6 +95,7 @@ void TcpServer::Init() {
     MS_LOG(EXCEPTION) << "Use event pthread failed!";
   }
+  is_stop_ = false;
   base_ = event_base_new();
   MS_EXCEPTION_IF_NULL(base_);
   if (!CommUtil::CheckIp(server_address_)) {
@@ -138,7 +139,6 @@ void TcpServer::Start() {
   std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
   MS_LOG(INFO) << "Start tcp server!";
   MS_EXCEPTION_IF_NULL(base_);
-  is_stop_ = false;
   int ret = event_base_dispatch(base_);
   MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
   MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
@@ -368,6 +368,7 @@ int TcpServer::ConnectionNum() const { return connections_.size(); }
 const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; }
 void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore
@@ -31,8 +31,8 @@ WorkerNode::~WorkerNode() {
       }
     }
     client_to_scheduler_->StopEventBase();
-    if (worker_thread_->joinable()) {
-      worker_thread_->join();
+    if (client_to_scheduler_thread_->joinable()) {
+      client_to_scheduler_thread_->join();
     }
     if (heart_beat_thread_->joinable()) {
       heart_beat_thread_->join();
@@ -43,7 +43,7 @@ WorkerNode::~WorkerNode() {
 bool WorkerNode::Start(const uint32_t &timeout) {
   MS_LOG(INFO) << "Starting worker node!";
   Initialize();
-  Register();
+  Register(client_to_scheduler_);
   Heartbeat(client_to_scheduler_);
   if (!WaitForStart(timeout)) {
| @@ -52,84 +52,25 @@ bool WorkerNode::Start(const uint32_t &timeout) { | |||||
| } | } | ||||
| MS_LOG(INFO) << "The node is ready to fetch servers!"; | MS_LOG(INFO) << "The node is ready to fetch servers!"; | ||||
| // If the cluster is ready to use, then Get the address of all the servers | |||||
| if (!is_timeout_.load()) { | if (!is_timeout_.load()) { | ||||
| FetchServers(client_to_scheduler_); | FetchServers(client_to_scheduler_); | ||||
| MS_LOG(INFO) << "Fetch servers successful!"; | |||||
| MS_LOG(INFO) << "Worker node get all the servers address successful!"; | |||||
| } | } | ||||
| MS_LOG(INFO) << "The Worker node has successfully started."; | MS_LOG(INFO) << "The Worker node has successfully started."; | ||||
| return true; | return true; | ||||
| } | } | ||||
| void WorkerNode::Register() { | |||||
| MessageMeta message_meta; | |||||
| message_meta.set_cmd(NodeCommand::REGISTER); | |||||
| RegisterMessage register_message; | |||||
| register_message.set_node_id(node_info_.node_id_); | |||||
| register_message.set_role(node_info_.node_role_); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||||
| comm_message.set_data(register_message.SerializeAsString()); | |||||
| if (!SendMessageSync(client_to_scheduler_, comm_message)) { | |||||
| MS_LOG(EXCEPTION) << "Worker node register timeout!"; | |||||
| } | |||||
| MS_LOG(INFO) << "The worker node id:" << node_info_.node_id_ | |||||
| << "is registering to scheduler, the request id is:" << message_meta.request_id(); | |||||
| } | |||||
| void WorkerNode::ProcessRegisterResp(const CommMessage &message) { | |||||
| RegisterRespMessage register_resp_message; | |||||
| register_resp_message.ParseFromString(message.data()); | |||||
| if (register_resp_message.node_id() != node_info_.node_id_) { | |||||
| MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() | |||||
| << " is not match the current node id:" << node_info_.node_id_; | |||||
| } | |||||
| node_info_.rank_id_ = register_resp_message.rank_id(); | |||||
| MS_LOG(INFO) << "The client node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; | |||||
| } | |||||
| void WorkerNode::Initialize() { | void WorkerNode::Initialize() { | ||||
| is_already_stopped_ = false; | is_already_stopped_ = false; | ||||
| node_info_.node_id_ = CommUtil::GenerateUUID(); | node_info_.node_id_ = CommUtil::GenerateUUID(); | ||||
| node_info_.node_role_ = NodeRole::WORKER; | node_info_.node_role_ = NodeRole::WORKER; | ||||
| MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | ||||
| << ", the node id is:" << node_info_.node_id_; | << ", the node id is:" << node_info_.node_id_; | ||||
| InitClientToScheduler(); | |||||
| } | |||||
| void WorkerNode::InitClientToScheduler() { | |||||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | |||||
| uint16_t scheduler_port = ClusterConfig::scheduler_port(); | |||||
| client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port); | |||||
| client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||||
| switch (message.pb_meta().cmd()) { | |||||
| case NodeCommand::HEARTBEAT: | |||||
| ProcessHeartbeatResp(message); | |||||
| break; | |||||
| case NodeCommand::REGISTER: | |||||
| ProcessRegisterResp(message); | |||||
| break; | |||||
| case NodeCommand::FETCH_SERVER: | |||||
| ProcessFetchServersResp(message); | |||||
| break; | |||||
| case NodeCommand::FINISH: | |||||
| MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!"; | |||||
| break; | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||||
| } | |||||
| NotifyMessageArrival(message); | |||||
| }); | |||||
| client_to_scheduler_->Init(); | |||||
| worker_thread_ = std::make_unique<std::thread>([&]() { | |||||
| MS_LOG(INFO) << "The worker node start a tcp client!"; | |||||
| client_to_scheduler_->Start(); | |||||
| }); | |||||
| worker_thread_->detach(); | |||||
| if (!InitClientToScheduler()) { | |||||
| MS_LOG(EXCEPTION) << "Worker node init client timeout!"; | |||||
| } | |||||
| MS_LOG(INFO) << "Worker node init client successful!"; | |||||
| } | } | ||||
| bool WorkerNode::Stop() { | bool WorkerNode::Stop() { | ||||
| @@ -144,8 +85,8 @@ bool WorkerNode::Stop() { | |||||
| } | } | ||||
| } | } | ||||
| client_to_scheduler_->StopEventBase(); | client_to_scheduler_->StopEventBase(); | ||||
| if (worker_thread_->joinable()) { | |||||
| worker_thread_->join(); | |||||
| if (client_to_scheduler_thread_->joinable()) { | |||||
| client_to_scheduler_thread_->join(); | |||||
| } | } | ||||
| if (heart_beat_thread_->joinable()) { | if (heart_beat_thread_->joinable()) { | ||||
| heart_beat_thread_->join(); | heart_beat_thread_->join(); | ||||
| @@ -165,23 +106,6 @@ bool WorkerNode::Finish(const uint32_t &timeout) { | |||||
| is_already_finished_ = true; | is_already_finished_ = true; | ||||
| return Disconnect(client_to_scheduler_, timeout); | return Disconnect(client_to_scheduler_, timeout); | ||||
| } | } | ||||
| bool WorkerNode::BroadcastToServers(const std::string &message) { | |||||
| uint64_t request_id = ++next_request_id_; | |||||
| message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); | |||||
| for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { | |||||
| MessageMeta message_meta; | |||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||||
| message_meta.set_request_id(request_id); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||||
| comm_message.set_data(message); | |||||
| auto client = GetOrCreateTcpClient((*it).first.second); | |||||
| client->SendMessage(comm_message); | |||||
| } | |||||
| return Wait(request_id); | |||||
| } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
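Note: the BroadcastToServers body removed above is built around a small request-tracking idiom: one request id covers the whole fan-out, the tracker records how many replies are expected, and the caller blocks on that id until every server has answered or a timeout fires. Below is a self-contained sketch of that idiom only; the RequestTracker name, the locking, and the timeout value are illustrative assumptions, not the MindSpore types, which keep this state inside the node classes.

#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <unordered_map>
#include <utility>

class RequestTracker {
 public:
  // Start tracking a fan-out of `expected` messages; returns the request id.
  uint64_t Begin(uint32_t expected) {
    std::lock_guard<std::mutex> lock(mutex_);
    uint64_t id = ++next_request_id_;
    tracker_[id] = {expected, 0};  // {expected replies, received replies}
    return id;
  }

  // Called from the receive path whenever a reply for `id` arrives.
  void NotifyArrival(uint64_t id) {
    std::lock_guard<std::mutex> lock(mutex_);
    ++tracker_[id].second;
    cv_.notify_all();
  }

  // Block until every expected reply has arrived; returns false on timeout.
  bool Wait(uint64_t id, std::chrono::seconds timeout = std::chrono::seconds(30)) {
    std::unique_lock<std::mutex> lock(mutex_);
    return cv_.wait_for(lock, timeout, [&] {
      auto &entry = tracker_[id];
      return entry.second >= entry.first;
    });
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  uint64_t next_request_id_{0};
  std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> tracker_;
};

int main() {
  RequestTracker tracker;
  uint64_t id = tracker.Begin(2);  // e.g. a broadcast to two servers
  tracker.NotifyArrival(id);       // reply from server 0
  tracker.NotifyArrival(id);       // reply from server 1
  return tracker.Wait(id) ? 0 : 1; // all replies are in, so this returns at once
}

In the real node code the receive callback calls NotifyMessageArrival for each incoming response, which is exactly the role NotifyArrival plays in this sketch.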
| @@ -17,50 +17,35 @@ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | #ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | ||||
| #define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | #define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | ||||
| #include <atomic> | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include <thread> | |||||
| #include <unordered_map> | |||||
| #include <utility> | #include <utility> | ||||
| #include <condition_variable> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <tuple> | |||||
| #include "proto/comm.pb.h" | #include "proto/comm.pb.h" | ||||
| #include "proto/ps.pb.h" | #include "proto/ps.pb.h" | ||||
| #include "ps/core/cluster_config.h" | #include "ps/core/cluster_config.h" | ||||
| #include "ps/core/tcp_client.h" | #include "ps/core/tcp_client.h" | ||||
| #include "ps/core/tcp_server.h" | #include "ps/core/tcp_server.h" | ||||
| #include "ps/core/node.h" | |||||
| #include "ps/core/abstract_node.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| class WorkerNode : public Node { | |||||
| class WorkerNode : public AbstractNode { | |||||
| public: | public: | ||||
| WorkerNode() : client_to_scheduler_(nullptr), worker_thread_(nullptr) {} | |||||
| WorkerNode() = default; | |||||
| ~WorkerNode() override; | ~WorkerNode() override; | ||||
| bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | ||||
| bool Stop() override; | bool Stop() override; | ||||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | ||||
| bool BroadcastToServers(const std::string &message); | |||||
| private: | private: | ||||
| void Register(); | |||||
| void ProcessRegisterResp(const CommMessage &message); | |||||
| void Initialize(); | void Initialize(); | ||||
| void InitClientToScheduler(); | |||||
| std::shared_ptr<TcpClient> client_to_scheduler_; | |||||
| std::unique_ptr<std::thread> worker_thread_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
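Note: with the header trimmed as above, WorkerNode inherits the scheduler-client plumbing it previously owned, including the per-command switch that the removed InitClientToScheduler installed as a message callback. A minimal sketch of that inheritance pattern follows; BaseNode, Worker, Message, and the handler bodies are simplified stand-ins for the real AbstractNode, WorkerNode, and CommMessage/NodeCommand protobuf types, chosen only to show how one shared dispatch can serve every node role.

#include <iostream>
#include <string>

enum class NodeCommand { HEARTBEAT, REGISTER, FETCH_SERVER, FINISH };

struct Message {
  NodeCommand cmd;
  std::string data;
};

class BaseNode {
 public:
  virtual ~BaseNode() = default;

  // Shared dispatch: every role handles scheduler responses the same way.
  void OnSchedulerMessage(const Message &message) {
    switch (message.cmd) {
      case NodeCommand::HEARTBEAT:
        ProcessHeartbeat(message);
        break;
      case NodeCommand::REGISTER:
        ProcessRegister(message);
        break;
      case NodeCommand::FETCH_SERVER:
        ProcessFetchServers(message);
        break;
      case NodeCommand::FINISH:
        std::cout << "finish acknowledged\n";
        break;
    }
  }

 protected:
  virtual void ProcessHeartbeat(const Message &) { std::cout << "heartbeat ok\n"; }
  virtual void ProcessRegister(const Message &) { std::cout << "registered\n"; }
  virtual void ProcessFetchServers(const Message &) { std::cout << "servers fetched\n"; }
};

class Worker : public BaseNode {};  // inherits the whole dispatch, overrides nothing

int main() {
  Worker worker;
  worker.OnSchedulerMessage({NodeCommand::REGISTER, ""});
  worker.OnSchedulerMessage({NodeCommand::FETCH_SERVER, ""});
  return 0;
}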
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/embedding_table_shard_metadata.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| uint64_t EmbeddingTableShardMetadata::begin() const { return begin_; } | |||||
| uint64_t EmbeddingTableShardMetadata::end() const { return end_; } | |||||
| uint64_t EmbeddingTableShardMetadata::size() const { return end_ - begin_; } | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_ | |||||
| #define MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_ | |||||
| #include <cstdint> | |||||
| #include <iostream> | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| class EmbeddingTableShardMetadata { | |||||
| public: | |||||
| explicit EmbeddingTableShardMetadata(uint64_t begin, uint64_t end) : begin_(begin), end_(end) {} | |||||
| virtual ~EmbeddingTableShardMetadata() = default; | |||||
| uint64_t begin() const; | |||||
| uint64_t end() const; | |||||
| uint64_t size() const; | |||||
| private: | |||||
| uint64_t begin_; | |||||
| uint64_t end_; | |||||
| }; | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_ | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common_test.h" | |||||
| #include "ps/embedding_table_shard_metadata.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| class TestEmbeddingTableShardMetadata : public UT::Common { | |||||
| public: | |||||
| TestEmbeddingTableShardMetadata() = default; | |||||
| virtual ~TestEmbeddingTableShardMetadata() = default; | |||||
| void SetUp() override {} | |||||
| void TearDown() override {} | |||||
| }; | |||||
| TEST_F(TestEmbeddingTableShardMetadata, EmbeddingTable) { | |||||
| EmbeddingTableShardMetadata embedding_table_shard(1, 100); | |||||
| EXPECT_EQ(embedding_table_shard.begin(), 1); | |||||
| EXPECT_EQ(embedding_table_shard.end(), 100); | |||||
| EXPECT_EQ(embedding_table_shard.size(), 99); | |||||
| } | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
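Note: the test above pins down the interval convention: begin() is inclusive, end() is exclusive, so size() equals end - begin and the shard [1, 100) holds 99 rows. The sketch below shows how such half-open shards might be produced when an embedding table is divided across servers; the SplitEmbeddingTable helper and its even-split policy are assumptions for illustration, not part of this patch.

#include <cstdint>
#include <iostream>
#include <vector>

struct Shard {
  uint64_t begin;  // first row owned by this shard
  uint64_t end;    // one past the last row, so size == end - begin
};

// Evenly split `vocab_size` rows over `server_num` shards; the first
// `vocab_size % server_num` shards take one extra row each.
std::vector<Shard> SplitEmbeddingTable(uint64_t vocab_size, uint64_t server_num) {
  std::vector<Shard> shards;
  uint64_t base = vocab_size / server_num;
  uint64_t remainder = vocab_size % server_num;
  uint64_t begin = 0;
  for (uint64_t i = 0; i < server_num; ++i) {
    uint64_t size = base + (i < remainder ? 1 : 0);
    shards.push_back({begin, begin + size});
    begin += size;
  }
  return shards;
}

int main() {
  for (const auto &shard : SplitEmbeddingTable(100, 3)) {
    std::cout << "[" << shard.begin << ", " << shard.end << ") size " << shard.end - shard.begin << "\n";
  }
  return 0;
}

Running it for 100 rows over 3 servers prints [0, 34), [34, 67), [67, 100), which cover the table with no gaps or overlaps.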