| @@ -15,6 +15,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc") | |||||
| endif () | endif () | ||||
| if (NOT ENABLE_D) | if (NOT ENABLE_D) | ||||
| @@ -94,16 +94,16 @@ std::string CommUtil::GenerateUUID() { | |||||
| ss << dis(gen); | ss << dis(gen); | ||||
| } | } | ||||
| ss << "-4"; | ss << "-4"; | ||||
| for (i = 0; i < kGroup2RandomLength - 1; i++) { | |||||
| for (i = 0; i < kGroup3RandomLength - 1; i++) { | |||||
| ss << dis(gen); | ss << dis(gen); | ||||
| } | } | ||||
| ss << "-"; | ss << "-"; | ||||
| ss << dis2(gen); | ss << dis2(gen); | ||||
| for (i = 0; i < kGroup3RandomLength - 1; i++) { | |||||
| for (i = 0; i < kGroup4RandomLength - 1; i++) { | |||||
| ss << dis(gen); | ss << dis(gen); | ||||
| } | } | ||||
| ss << "-"; | ss << "-"; | ||||
| for (i = 0; i < kGroup4RandomLength; i++) { | |||||
| for (i = 0; i < kGroup5RandomLength; i++) { | |||||
| ss << dis(gen); | ss << dis(gen); | ||||
| } | } | ||||
| return ss.str(); | return ss.str(); | ||||
| @@ -121,7 +121,14 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) { | |||||
| MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!"; | MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!"; | ||||
| } | } | ||||
| } | } | ||||
| bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) { | |||||
| if (node_role == NodeRole::SERVER && (rank_id > ClusterConfig::server_num() - 1)) { | |||||
| return false; | |||||
| } else if (node_role == NodeRole::WORKER && (rank_id > ClusterConfig::worker_num() - 1)) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -48,6 +48,7 @@ | |||||
| #include "proto/comm.pb.h" | #include "proto/comm.pb.h" | ||||
| #include "proto/ps.pb.h" | #include "proto/ps.pb.h" | ||||
| #include "ps/core/cluster_config.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -66,6 +67,7 @@ class CommUtil { | |||||
| static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); | static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); | ||||
| static std::string GenerateUUID(); | static std::string GenerateUUID(); | ||||
| static std::string NodeRoleToString(const NodeRole &role); | static std::string NodeRoleToString(const NodeRole &role); | ||||
| static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id); | |||||
| private: | private: | ||||
| static std::random_device rd; | static std::random_device rd; | ||||
| @@ -47,13 +47,17 @@ void Node::ProcessHeartbeatResp(const CommMessage &message) { | |||||
| is_ready_ = heartbeat_resp_message.is_cluster_ready(); | is_ready_ = heartbeat_resp_message.is_cluster_ready(); | ||||
| if (is_ready_.load()) { | if (is_ready_.load()) { | ||||
| wait_start_cond_.notify_all(); | wait_start_cond_.notify_all(); | ||||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!"; | |||||
| } | } | ||||
| is_finish_ = heartbeat_resp_message.is_cluster_finish(); | is_finish_ = heartbeat_resp_message.is_cluster_finish(); | ||||
| if (is_finish_.load()) { | if (is_finish_.load()) { | ||||
| wait_finish_cond_.notify_all(); | wait_finish_cond_.notify_all(); | ||||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!"; | |||||
| } | } | ||||
| is_timeout_ = heartbeat_resp_message.is_cluster_timeout(); | is_timeout_ = heartbeat_resp_message.is_cluster_timeout(); | ||||
| if (is_timeout_ && on_node_event_message_) { | if (is_timeout_ && on_node_event_message_) { | ||||
| is_ready_ = true; | |||||
| wait_start_cond_.notify_all(); | |||||
| on_node_event_message_(NodeEvent::NODE_TIMEOUT); | on_node_event_message_(NodeEvent::NODE_TIMEOUT); | ||||
| } | } | ||||
| } | } | ||||
| @@ -64,7 +68,9 @@ void Node::FetchServers(const std::shared_ptr<TcpClient> &client) { | |||||
| CommMessage message; | CommMessage message; | ||||
| *message.mutable_pb_meta() = {meta}; | *message.mutable_pb_meta() = {meta}; | ||||
| SendMessageSync(client, message); | |||||
| if (!SendMessageSync(client, message)) { | |||||
| MS_LOG(EXCEPTION) << "Fetch servers address timeout!"; | |||||
| } | |||||
| } | } | ||||
| void Node::ProcessFetchServersResp(const CommMessage &message) { | void Node::ProcessFetchServersResp(const CommMessage &message) { | ||||
| @@ -72,10 +78,10 @@ void Node::ProcessFetchServersResp(const CommMessage &message) { | |||||
| fetch_servers_resp_message.ParseFromString(message.data()); | fetch_servers_resp_message.ParseFromString(message.data()); | ||||
| for (const auto &it : fetch_servers_resp_message.servers_meta()) { | for (const auto &it : fetch_servers_resp_message.servers_meta()) { | ||||
| server_rank_ids_[it.rank_id()] = std::make_pair(it.ip(), it.port()); | |||||
| nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port()); | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "The all server host size is:" << server_rank_ids_.size(); | |||||
| MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size(); | |||||
| } | } | ||||
| std::string Node::node_id() const { return node_info_.node_id_; } | std::string Node::node_id() const { return node_info_.node_id_; } | ||||
| @@ -86,19 +92,128 @@ void Node::set_callback(const OnNodeEventMessage &on_node_event_message) { | |||||
| on_node_event_message_ = on_node_event_message; | on_node_event_message_ = on_node_event_message; | ||||
| } | } | ||||
| void Node::Wait(uint64_t request_id) { | |||||
| std::unique_lock<std::mutex> lock(message_mutex_); | |||||
| message_tracker_cond_.wait(lock, [&] { | |||||
| bool Node::Wait(uint64_t request_id, const uint32_t &timeout) { | |||||
| std::unique_lock<std::mutex> lock(message_tracker_mutex_); | |||||
| bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||||
| bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second; | bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second; | ||||
| if (ret) { | |||||
| MS_LOG(DEBUG) << "Message tracker remove request id:" << request_id; | |||||
| message_tracker_.erase(request_id); | |||||
| } | |||||
| return ret; | return ret; | ||||
| }); | }); | ||||
| message_tracker_.erase(request_id); | |||||
| return res; | |||||
| } | |||||
| bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||||
| const uint32_t &timeout) { | |||||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||||
| } | |||||
| MessageMeta message_meta; | |||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||||
| comm_message.set_data(message); | |||||
| auto client = GetOrCreateTcpClient(rank_id); | |||||
| return SendMessageSync(client, comm_message); | |||||
| } | } | ||||
| void Node::Disconnect(const std::shared_ptr<TcpClient> &client) { | |||||
| bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||||
| const uint32_t &timeout) { | |||||
| uint64_t request_id = ++next_request_id_; | |||||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||||
| if (rank_ids.size() != data.size()) { | |||||
| MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!"; | |||||
| } | |||||
| for (size_t it = 0; it < rank_ids.size(); ++it) { | |||||
| if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { | |||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||||
| } | |||||
| MessageMeta message_meta; | |||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||||
| message_meta.set_request_id(request_id); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||||
| comm_message.set_data(data.at(it)); | |||||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||||
| client->SendMessage(comm_message); | |||||
| } | |||||
| return Wait(request_id, timeout); | |||||
| } | |||||
| bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||||
| CommMessage *comm_message_resp, const uint32_t &timeout) { | |||||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||||
| } | |||||
| uint64_t request_id = ++next_request_id_; | |||||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||||
| set_message_callback(request_id, [&]() { | |||||
| receive_messages_mutex_.lock(); | |||||
| auto res = receive_messages_[request_id]; | |||||
| comm_message_resp = &res[rank_id]; | |||||
| receive_messages_.erase(request_id); | |||||
| receive_messages_mutex_.unlock(); | |||||
| }); | |||||
| MessageMeta message_meta; | |||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||||
| message_meta.set_request_id(request_id); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||||
| comm_message.set_data(message); | |||||
| auto client = GetOrCreateTcpClient(rank_id); | |||||
| client->SendMessage(comm_message); | |||||
| return Wait(request_id, timeout); | |||||
| } | |||||
| bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||||
| std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) { | |||||
| uint64_t request_id = ++next_request_id_; | |||||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||||
| if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) { | |||||
| MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; | |||||
| } | |||||
| size_t len = rank_ids.size(); | |||||
| set_message_callback(request_id, [&]() { | |||||
| receive_messages_mutex_.lock(); | |||||
| auto res = receive_messages_[request_id]; | |||||
| for (size_t it = 0; it < len; ++it) { | |||||
| comm_message_resp->at(it) = &res[rank_ids.at(it)]; | |||||
| } | |||||
| receive_messages_.erase(request_id); | |||||
| receive_messages_mutex_.unlock(); | |||||
| }); | |||||
| for (size_t it = 0; it < len; ++it) { | |||||
| if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { | |||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||||
| } | |||||
| MessageMeta message_meta; | |||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||||
| message_meta.set_request_id(request_id); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||||
| comm_message.set_data(data.at(it)); | |||||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||||
| client->SendMessage(comm_message); | |||||
| } | |||||
| return Wait(request_id, timeout); | |||||
| } | |||||
| bool Node::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) { | |||||
| MessageMeta meta; | MessageMeta meta; | ||||
| meta.set_cmd(NodeCommand::FINISH); | meta.set_cmd(NodeCommand::FINISH); | ||||
| @@ -108,36 +223,43 @@ void Node::Disconnect(const std::shared_ptr<TcpClient> &client) { | |||||
| CommMessage message; | CommMessage message; | ||||
| *message.mutable_pb_meta() = {meta}; | *message.mutable_pb_meta() = {meta}; | ||||
| message.set_data(finish_message.SerializeAsString()); | message.set_data(finish_message.SerializeAsString()); | ||||
| SendMessageSync(client, message); | |||||
| WaitForDisconnect(); | |||||
| if (!SendMessageSync(client, message)) { | |||||
| MS_LOG(EXCEPTION) << "Disconnect timeout!"; | |||||
| } | |||||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!"; | |||||
| return WaitForDisconnect(timeout); | |||||
| } | } | ||||
| void Node::WaitForStart() { | |||||
| bool Node::WaitForStart(const uint32_t &timeout) { | |||||
| std::unique_lock<std::mutex> lock(wait_start_mutex_); | std::unique_lock<std::mutex> lock(wait_start_mutex_); | ||||
| wait_start_cond_.wait(lock, [&] { | |||||
| if (is_ready_.load()) { | |||||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success start!"; | |||||
| bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||||
| bool res = is_ready_.load(); | |||||
| if (res) { | |||||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!"; | |||||
| } | } | ||||
| return is_ready_.load(); | |||||
| return res; | |||||
| }); | }); | ||||
| return res; | |||||
| } | } | ||||
| void Node::WaitForDisconnect() { | |||||
| bool Node::WaitForDisconnect(const uint32_t &timeout) { | |||||
| std::unique_lock<std::mutex> lock(wait_finish_mutex_); | std::unique_lock<std::mutex> lock(wait_finish_mutex_); | ||||
| wait_finish_cond_.wait(lock, [&] { | |||||
| bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||||
| if (is_finish_.load()) { | if (is_finish_.load()) { | ||||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success finish!"; | |||||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!"; | |||||
| } | } | ||||
| return is_finish_.load(); | return is_finish_.load(); | ||||
| }); | }); | ||||
| return res; | |||||
| } | } | ||||
| void Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||||
| bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||||
| const uint32_t &timeout) { | |||||
| uint64_t request_id = ++next_request_id_; | uint64_t request_id = ++next_request_id_; | ||||
| message_tracker_[request_id] = std::make_pair(1, 0); | message_tracker_[request_id] = std::make_pair(1, 0); | ||||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | ||||
| client->SendMessage(message); | client->SendMessage(message); | ||||
| Wait(request_id); | |||||
| return Wait(request_id, timeout); | |||||
| } | } | ||||
| void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | ||||
| @@ -147,12 +269,83 @@ void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const Comm | |||||
| } | } | ||||
| void Node::NotifyMessageArrival(const CommMessage &message) { | void Node::NotifyMessageArrival(const CommMessage &message) { | ||||
| std::lock_guard<std::mutex> lock(message_tracker_mutex_); | |||||
| const MessageMeta &message_meta = message.pb_meta(); | const MessageMeta &message_meta = message.pb_meta(); | ||||
| uint64_t request_id = message_meta.request_id(); | uint64_t request_id = message_meta.request_id(); | ||||
| message_tracker_[request_id].second++; | message_tracker_[request_id].second++; | ||||
| message_tracker_cond_.notify_all(); | message_tracker_cond_.notify_all(); | ||||
| } | } | ||||
| const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) { | |||||
| std::lock_guard<std::mutex> lock(client_mutex_); | |||||
| if (connected_nodes_.find(rank_id) != connected_nodes_.end()) { | |||||
| return connected_nodes_[rank_id]; | |||||
| } else { | |||||
| if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) { | |||||
| MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!"; | |||||
| } | |||||
| std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; | |||||
| uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; | |||||
| auto client = std::make_shared<TcpClient>(ip, port); | |||||
| client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||||
| switch (message.pb_meta().cmd()) { | |||||
| case NodeCommand::SEND_DATA: | |||||
| ProcessSendDataResp(message); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||||
| } | |||||
| NotifyMessageArrival(message); | |||||
| }); | |||||
| client->Init(); | |||||
| connected_nodes_[rank_id] = client; | |||||
| return connected_nodes_[rank_id]; | |||||
| } | |||||
| } | |||||
| void Node::ProcessSendDataResp(const CommMessage &message) { | |||||
| std::lock_guard<std::mutex> lock(receive_messages_mutex_); | |||||
| const MessageMeta &message_meta = message.pb_meta(); | |||||
| const uint32_t &rank_id = message_meta.rank_id(); | |||||
| const uint64_t request_id = message_meta.request_id(); | |||||
| auto it = receive_messages_.find(request_id); | |||||
| if (it != receive_messages_.end()) { | |||||
| it->second.insert(std::make_pair(rank_id, message)); | |||||
| } else { | |||||
| std::unordered_map<uint32_t, CommMessage> res; | |||||
| res.insert(std::make_pair(rank_id, message)); | |||||
| receive_messages_[request_id] = res; | |||||
| } | |||||
| RunMessageCallback(request_id); | |||||
| } | |||||
| void Node::RunMessageCallback(const uint64_t &request_id) { | |||||
| message_callbacks_mutex_.lock(); | |||||
| if (message_tracker_[request_id].first == message_tracker_[request_id].second - 1) { | |||||
| auto it = message_callbacks_.find(request_id); | |||||
| if (it != message_callbacks_.end()) { | |||||
| message_callbacks_mutex_.unlock(); | |||||
| if (it->second) { | |||||
| it->second(); | |||||
| } | |||||
| message_callbacks_mutex_.lock(); | |||||
| message_callbacks_.erase(it); | |||||
| } | |||||
| } | |||||
| message_callbacks_mutex_.unlock(); | |||||
| } | |||||
| void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) { | |||||
| if (!message_callback) { | |||||
| return; | |||||
| } | |||||
| std::lock_guard<std::mutex> lock(message_callbacks_mutex_); | |||||
| message_callbacks_[request_id] = message_callback; | |||||
| } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,15 +21,15 @@ | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <functional> | #include <functional> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <map> | |||||
| #include <memory> | #include <memory> | ||||
| #include <set> | |||||
| #include <string> | #include <string> | ||||
| #include <thread> | #include <thread> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include <condition_variable> | #include <condition_variable> | ||||
| #include <utility> | #include <utility> | ||||
| #include <tuple> | |||||
| #include <map> | |||||
| #include "proto/comm.pb.h" | #include "proto/comm.pb.h" | ||||
| #include "proto/ps.pb.h" | #include "proto/ps.pb.h" | ||||
| @@ -42,6 +42,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| constexpr int kTimeoutInSeconds = 30; | |||||
| constexpr int kCommTimeoutInSeconds = 3; | |||||
| class Node { | class Node { | ||||
| public: | public: | ||||
| Node() | Node() | ||||
| @@ -49,51 +51,83 @@ class Node { | |||||
| is_finish_(false), | is_finish_(false), | ||||
| is_timeout_(false), | is_timeout_(false), | ||||
| is_already_stopped_(true), | is_already_stopped_(true), | ||||
| is_already_finished_(false), | |||||
| next_request_id_(0), | next_request_id_(0), | ||||
| heart_beat_thread_(nullptr) {} | heart_beat_thread_(nullptr) {} | ||||
| virtual ~Node() = default; | virtual ~Node() = default; | ||||
| using OnNodeEventMessage = std::function<void(const NodeEvent &event)>; | using OnNodeEventMessage = std::function<void(const NodeEvent &event)>; | ||||
| void set_callback(const OnNodeEventMessage &on_node_event_message); | |||||
| using MessageCallback = std::function<void()>; | |||||
| virtual bool Start(const uint32_t &timeout = kTimeoutInSeconds) = 0; | |||||
| virtual bool Stop() = 0; | |||||
| virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; | |||||
| void set_callback(const OnNodeEventMessage &on_node_event_message); | |||||
| std::string node_id() const; | std::string node_id() const; | ||||
| uint32_t rank_id() const; | uint32_t rank_id() const; | ||||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| void Wait(uint64_t request_id); | |||||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||||
| const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||||
| CommMessage *const comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||||
| const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| protected: | protected: | ||||
| void Heartbeat(const std::shared_ptr<TcpClient> &client); | void Heartbeat(const std::shared_ptr<TcpClient> &client); | ||||
| void ProcessHeartbeatResp(const CommMessage &message); | void ProcessHeartbeatResp(const CommMessage &message); | ||||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | void FetchServers(const std::shared_ptr<TcpClient> &client); | ||||
| void ProcessFetchServersResp(const CommMessage &message); | void ProcessFetchServersResp(const CommMessage &message); | ||||
| void Disconnect(const std::shared_ptr<TcpClient> &client); | |||||
| void WaitForStart(); | |||||
| void WaitForDisconnect(); | |||||
| void SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||||
| bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); | |||||
| bool WaitForStart(const uint32_t &timeout); | |||||
| bool WaitForDisconnect(const uint32_t &timeout); | |||||
| bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | ||||
| void NotifyMessageArrival(const CommMessage &message); | void NotifyMessageArrival(const CommMessage &message); | ||||
| const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); | |||||
| void ProcessSendDataResp(const CommMessage &message); | |||||
| void RunMessageCallback(const uint64_t &request_id); | |||||
| void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); | |||||
| NodeInfo node_info_; | NodeInfo node_info_; | ||||
| std::atomic<bool> is_ready_; | std::atomic<bool> is_ready_; | ||||
| std::atomic<bool> is_finish_; | std::atomic<bool> is_finish_; | ||||
| std::atomic<bool> is_timeout_; | std::atomic<bool> is_timeout_; | ||||
| std::atomic<bool> is_already_stopped_; | std::atomic<bool> is_already_stopped_; | ||||
| std::atomic<bool> is_already_finished_; | |||||
| std::atomic_uint64_t next_request_id_; | std::atomic_uint64_t next_request_id_; | ||||
| std::unique_ptr<std::thread> heart_beat_thread_; | std::unique_ptr<std::thread> heart_beat_thread_; | ||||
| OnNodeEventMessage on_node_event_message_; | OnNodeEventMessage on_node_event_message_; | ||||
| // rank_id-><ip, port> | |||||
| std::unordered_map<int, std::pair<std::string, uint16_t>> server_rank_ids_; | |||||
| // <NodeRole,rank_id>-><ip, port> | |||||
| std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; | |||||
| // rank_id->tcpclient | |||||
| std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_; | |||||
| // timestamp-><expected responses, actual responses> | |||||
| // request_id-><expected responses, actual responses> | |||||
| std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; | std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; | ||||
| std::mutex message_mutex_; | |||||
| std::mutex message_tracker_mutex_; | |||||
| std::condition_variable message_tracker_cond_; | std::condition_variable message_tracker_cond_; | ||||
| std::mutex wait_finish_mutex_; | std::mutex wait_finish_mutex_; | ||||
| std::condition_variable wait_finish_cond_; | std::condition_variable wait_finish_cond_; | ||||
| std::mutex wait_start_mutex_; | std::mutex wait_start_mutex_; | ||||
| std::condition_variable wait_start_cond_; | std::condition_variable wait_start_cond_; | ||||
| std::mutex finish_mutex_; | |||||
| std::mutex client_mutex_; | |||||
| // request_id -> <rank_id, CommMessage> | |||||
| std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_; | |||||
| std::mutex receive_messages_mutex_; | |||||
| // request_id -> MessageCallback | |||||
| std::unordered_map<uint64_t, MessageCallback> message_callbacks_; | |||||
| std::mutex message_callbacks_mutex_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -39,6 +39,10 @@ message MessageMeta { | |||||
| NodeCommand cmd = 1; | NodeCommand cmd = 1; | ||||
| // the request id of this message | // the request id of this message | ||||
| uint64 request_id = 2; | uint64 request_id = 2; | ||||
| // the role of the current node: worker,server,scheduler | |||||
| NodeRole role = 3; | |||||
| // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] | |||||
| int32 rank_id = 4; | |||||
| } | } | ||||
| message RegisterMessage { | message RegisterMessage { | ||||
| @@ -249,7 +249,7 @@ void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb | |||||
| void TcpClient::SendMessage(const CommMessage &message) const { | void TcpClient::SendMessage(const CommMessage &message) const { | ||||
| MS_EXCEPTION_IF_NULL(buffer_event_); | MS_EXCEPTION_IF_NULL(buffer_event_); | ||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| size_t buf_size = message.ByteSizeLong(); | |||||
| std::vector<unsigned char> serialized(buf_size); | std::vector<unsigned char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | ||||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { | if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { | ||||
| @@ -23,7 +23,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; } | void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; } | ||||
| void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | ||||
| @@ -32,11 +31,11 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||||
| while (num > 0) { | while (num > 0) { | ||||
| if (remaining_length_ == 0) { | if (remaining_length_ == 0) { | ||||
| for (int i = 0; i < 4 && num > 0; ++i) { | |||||
| for (int i = 0; i < kHeaderLen && num > 0; ++i) { | |||||
| header_[++header_index_] = *(buffer_data + i); | header_[++header_index_] = *(buffer_data + i); | ||||
| --num; | --num; | ||||
| if (header_index_ == 3) { | |||||
| message_length_ = *reinterpret_cast<const uint32_t *>(header_); | |||||
| if (header_index_ == kHeaderLen - 1) { | |||||
| message_length_ = *reinterpret_cast<const size_t *>(header_); | |||||
| remaining_length_ = message_length_; | remaining_length_ = message_length_; | ||||
| message_buffer_.reset(new unsigned char[remaining_length_]); | message_buffer_.reset(new unsigned char[remaining_length_]); | ||||
| buffer_data += (i + 1); | buffer_data += (i + 1); | ||||
| @@ -46,7 +45,7 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||||
| } | } | ||||
| if (remaining_length_ > 0 && num > 0) { | if (remaining_length_ > 0 && num > 0) { | ||||
| uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num; | |||||
| size_t copy_len = remaining_length_ <= num ? remaining_length_ : num; | |||||
| remaining_length_ -= copy_len; | remaining_length_ -= copy_len; | ||||
| num -= copy_len; | num -= copy_len; | ||||
| @@ -71,7 +70,6 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -31,6 +31,7 @@ namespace mindspore { | |||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| using messageReceive = std::function<void(const CommMessage &message)>; | using messageReceive = std::function<void(const CommMessage &message)>; | ||||
| constexpr int kHeaderLen = 8; | |||||
| class TcpMessageHandler { | class TcpMessageHandler { | ||||
| public: | public: | ||||
| @@ -51,10 +52,10 @@ class TcpMessageHandler { | |||||
| bool is_parsed_; | bool is_parsed_; | ||||
| std::unique_ptr<unsigned char> message_buffer_; | std::unique_ptr<unsigned char> message_buffer_; | ||||
| size_t message_length_; | size_t message_length_; | ||||
| uint32_t remaining_length_; | |||||
| char header_[4]; | |||||
| size_t remaining_length_; | |||||
| char header_[8]; | |||||
| int header_index_; | int header_index_; | ||||
| uint32_t last_copy_len_; | |||||
| size_t last_copy_len_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -55,7 +55,7 @@ const evutil_socket_t &TcpConnection::GetFd() const { return fd_; } | |||||
| void TcpConnection::SendMessage(const CommMessage &message) const { | void TcpConnection::SendMessage(const CommMessage &message) const { | ||||
| MS_EXCEPTION_IF_NULL(buffer_event_); | MS_EXCEPTION_IF_NULL(buffer_event_); | ||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| size_t buf_size = message.ByteSizeLong(); | |||||
| std::vector<unsigned char> serialized(buf_size); | std::vector<unsigned char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | ||||
| if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size, | if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size, | ||||
| @@ -0,0 +1,187 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/core/worker_node.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| WorkerNode::~WorkerNode() { | |||||
| MS_LOG(INFO) << "Stop worker node!"; | |||||
| if (!is_already_stopped_.load()) { | |||||
| is_ready_ = true; | |||||
| is_timeout_ = true; | |||||
| client_to_scheduler_->Stop(); | |||||
| if (!connected_nodes_.empty()) { | |||||
| for (auto &connected_node : connected_nodes_) { | |||||
| connected_node.second->Stop(); | |||||
| } | |||||
| } | |||||
| client_to_scheduler_->StopEventBase(); | |||||
| if (worker_thread_->joinable()) { | |||||
| worker_thread_->join(); | |||||
| } | |||||
| if (heart_beat_thread_->joinable()) { | |||||
| heart_beat_thread_->join(); | |||||
| } | |||||
| is_already_stopped_ = true; | |||||
| } | |||||
| } | |||||
| bool WorkerNode::Start(const uint32_t &timeout) { | |||||
| MS_LOG(INFO) << "Starting worker node!"; | |||||
| Initialize(); | |||||
| Register(); | |||||
| Heartbeat(client_to_scheduler_); | |||||
| if (!WaitForStart(timeout)) { | |||||
| MS_LOG(ERROR) << "Start Worker node timeout!"; | |||||
| return false; | |||||
| } | |||||
| MS_LOG(INFO) << "The node is ready to fetch servers!"; | |||||
| if (!is_timeout_.load()) { | |||||
| FetchServers(client_to_scheduler_); | |||||
| MS_LOG(INFO) << "Fetch servers successful!"; | |||||
| } | |||||
| MS_LOG(INFO) << "The Worker node has successfully started."; | |||||
| return true; | |||||
| } | |||||
| void WorkerNode::Register() { | |||||
| MessageMeta message_meta; | |||||
| message_meta.set_cmd(NodeCommand::REGISTER); | |||||
| RegisterMessage register_message; | |||||
| register_message.set_node_id(node_info_.node_id_); | |||||
| register_message.set_role(node_info_.node_role_); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||||
| comm_message.set_data(register_message.SerializeAsString()); | |||||
| if (!SendMessageSync(client_to_scheduler_, comm_message)) { | |||||
| MS_LOG(EXCEPTION) << "Worker node register timeout!"; | |||||
| } | |||||
| MS_LOG(INFO) << "The worker node id:" << node_info_.node_id_ | |||||
| << "is registering to scheduler, the request id is:" << message_meta.request_id(); | |||||
| } | |||||
| void WorkerNode::ProcessRegisterResp(const CommMessage &message) { | |||||
| RegisterRespMessage register_resp_message; | |||||
| register_resp_message.ParseFromString(message.data()); | |||||
| if (register_resp_message.node_id() != node_info_.node_id_) { | |||||
| MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() | |||||
| << " is not match the current node id:" << node_info_.node_id_; | |||||
| } | |||||
| node_info_.rank_id_ = register_resp_message.rank_id(); | |||||
| MS_LOG(INFO) << "The client node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; | |||||
| } | |||||
| void WorkerNode::Initialize() { | |||||
| is_already_stopped_ = false; | |||||
| node_info_.node_id_ = CommUtil::GenerateUUID(); | |||||
| node_info_.node_role_ = NodeRole::WORKER; | |||||
| MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_; | |||||
| InitClientToScheduler(); | |||||
| } | |||||
| void WorkerNode::InitClientToScheduler() { | |||||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | |||||
| uint16_t scheduler_port = ClusterConfig::scheduler_port(); | |||||
| client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port); | |||||
| client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||||
| switch (message.pb_meta().cmd()) { | |||||
| case NodeCommand::HEARTBEAT: | |||||
| ProcessHeartbeatResp(message); | |||||
| break; | |||||
| case NodeCommand::REGISTER: | |||||
| ProcessRegisterResp(message); | |||||
| break; | |||||
| case NodeCommand::FETCH_SERVER: | |||||
| ProcessFetchServersResp(message); | |||||
| break; | |||||
| case NodeCommand::FINISH: | |||||
| MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!"; | |||||
| break; | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||||
| } | |||||
| NotifyMessageArrival(message); | |||||
| }); | |||||
| client_to_scheduler_->Init(); | |||||
| worker_thread_ = std::make_unique<std::thread>([&]() { | |||||
| MS_LOG(INFO) << "The worker node start a tcp client!"; | |||||
| client_to_scheduler_->Start(); | |||||
| }); | |||||
| worker_thread_->detach(); | |||||
| } | |||||
| bool WorkerNode::Stop() { | |||||
| MS_LOG(INFO) << "Stop worker node!"; | |||||
| if (!is_already_stopped_.load()) { | |||||
| is_ready_ = true; | |||||
| is_timeout_ = true; | |||||
| client_to_scheduler_->Stop(); | |||||
| if (!connected_nodes_.empty()) { | |||||
| for (auto &connected_node : connected_nodes_) { | |||||
| connected_node.second->Stop(); | |||||
| } | |||||
| } | |||||
| client_to_scheduler_->StopEventBase(); | |||||
| if (worker_thread_->joinable()) { | |||||
| worker_thread_->join(); | |||||
| } | |||||
| if (heart_beat_thread_->joinable()) { | |||||
| heart_beat_thread_->join(); | |||||
| } | |||||
| is_already_stopped_ = true; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool WorkerNode::Finish(const uint32_t &timeout) { | |||||
| std::lock_guard<std::mutex> lock(finish_mutex_); | |||||
| if (is_already_finished_) { | |||||
| MS_LOG(INFO) << "Worker node already finish!"; | |||||
| return true; | |||||
| } | |||||
| MS_LOG(INFO) << "Finish worker node!"; | |||||
| is_already_finished_ = true; | |||||
| return Disconnect(client_to_scheduler_, timeout); | |||||
| } | |||||
| bool WorkerNode::BroadcastToServers(const std::string &message) { | |||||
| uint64_t request_id = ++next_request_id_; | |||||
| message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); | |||||
| for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { | |||||
| MessageMeta message_meta; | |||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||||
| message_meta.set_request_id(request_id); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||||
| comm_message.set_data(message); | |||||
| auto client = GetOrCreateTcpClient((*it).first.second); | |||||
| client->SendMessage(comm_message); | |||||
| } | |||||
| return Wait(request_id); | |||||
| } | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | |||||
| #include <atomic> | |||||
| #include <cstdlib> | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <thread> | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include <condition_variable> | |||||
| #include <algorithm> | |||||
| #include <tuple> | |||||
| #include "proto/comm.pb.h" | |||||
| #include "proto/ps.pb.h" | |||||
| #include "ps/core/cluster_config.h" | |||||
| #include "ps/core/tcp_client.h" | |||||
| #include "ps/core/tcp_server.h" | |||||
| #include "ps/core/node.h" | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| class WorkerNode : public Node { | |||||
| public: | |||||
| WorkerNode() : client_to_scheduler_(nullptr), worker_thread_(nullptr) {} | |||||
| ~WorkerNode() override; | |||||
| bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | |||||
| bool Stop() override; | |||||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | |||||
| bool BroadcastToServers(const std::string &message); | |||||
| private: | |||||
| void Register(); | |||||
| void ProcessRegisterResp(const CommMessage &message); | |||||
| void Initialize(); | |||||
| void InitClientToScheduler(); | |||||
| std::shared_ptr<TcpClient> client_to_scheduler_; | |||||
| std::unique_ptr<std::thread> worker_thread_; | |||||
| }; | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | |||||
| @@ -39,6 +39,14 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { | |||||
| EXPECT_TRUE(!interface.empty()); | EXPECT_TRUE(!interface.empty()); | ||||
| EXPECT_TRUE(!ip.empty()); | EXPECT_TRUE(!ip.empty()); | ||||
| } | } | ||||
| TEST_F(TestCommUtil, ValidateRankId) { | |||||
| ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999); | |||||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); | |||||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); | |||||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); | |||||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2)); | |||||
| } | |||||
| } // namespace comm | } // namespace comm | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,117 +33,118 @@ class TestTcpMessageHandler : public UT::Common { | |||||
| void TearDown() override {} | void TearDown() override {} | ||||
| }; | }; | ||||
| TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) { | |||||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { | |||||
| TcpMessageHandler handler; | TcpMessageHandler handler; | ||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | ||||
| std::string data(1000, 'a'); | std::string data(1000, 'a'); | ||||
| CommMessage message; | CommMessage message; | ||||
| message.set_data(data); | message.set_data(data); | ||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| char result[1007]; | |||||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||||
| size_t buf_size = message.ByteSizeLong(); | |||||
| char result[1011]; | |||||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| std::vector<char> serialized(buf_size); | std::vector<char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | ||||
| memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||||
| handler.ReceiveMessage(result, buf_size + 4); | |||||
| memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||||
| handler.ReceiveMessage(result, buf_size + kHeaderLen); | |||||
| } | } | ||||
| TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) { | |||||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { | |||||
| TcpMessageHandler handler; | TcpMessageHandler handler; | ||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | ||||
| std::string data(1000, 'a'); | std::string data(1000, 'a'); | ||||
| CommMessage message; | CommMessage message; | ||||
| message.set_data(data); | message.set_data(data); | ||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| char result[2014]; | |||||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||||
| size_t buf_size = message.ByteSizeLong(); | |||||
| char result[2022] = {0}; | |||||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| std::vector<char> serialized(buf_size); | std::vector<char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | ||||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||||
| ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); | |||||
| ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size); | |||||
| ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| handler.ReceiveMessage(result, 2 * buf_size + 4 * 2); | |||||
| handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2); | |||||
| } | } | ||||
| TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) { | |||||
| TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { | |||||
| TcpMessageHandler handler; | TcpMessageHandler handler; | ||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); }); | |||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); }); | |||||
| std::string data(4087, 'a'); | |||||
| std::string data(4081, 'a'); | |||||
| CommMessage message; | CommMessage message; | ||||
| message.set_data(data); | message.set_data(data); | ||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| char result[4096]; | |||||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||||
| size_t buf_size = message.ByteSizeLong(); | |||||
| char result[4096] = {0}; | |||||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| std::vector<char> serialized(buf_size); | std::vector<char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | ||||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||||
| ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2); | |||||
| ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| handler.ReceiveMessage(result, 4096); | handler.ReceiveMessage(result, 4096); | ||||
| ret = memcpy_s(result, 2, &buf_size + 2, 2); | |||||
| auto temp = reinterpret_cast<char *>(&buf_size); | |||||
| ret = memcpy_s(result, 4, temp + 4, 4); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size); | |||||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| handler.ReceiveMessage(result, 4092); | |||||
| handler.ReceiveMessage(result, 4088); | |||||
| } | } | ||||
| TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) { | |||||
| TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { | |||||
| TcpMessageHandler handler; | TcpMessageHandler handler; | ||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); }); | |||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); }); | |||||
| std::string data(4085, 'a'); | |||||
| std::string data(4077, 'a'); | |||||
| CommMessage message; | CommMessage message; | ||||
| message.set_data(data); | message.set_data(data); | ||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| char result[4096]; | |||||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||||
| size_t buf_size = message.ByteSizeLong(); | |||||
| char result[4096] = {0}; | |||||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| std::vector<char> serialized(buf_size); | std::vector<char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | ||||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||||
| ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); | |||||
| ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -155,9 +156,8 @@ TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| handler.ReceiveMessage(result, 4088); | |||||
| handler.ReceiveMessage(result, 4080); | |||||
| } | } | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||