| @@ -15,6 +15,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc") | |||
| endif () | |||
| if (NOT ENABLE_D) | |||
| @@ -94,16 +94,16 @@ std::string CommUtil::GenerateUUID() { | |||
| ss << dis(gen); | |||
| } | |||
| ss << "-4"; | |||
| for (i = 0; i < kGroup2RandomLength - 1; i++) { | |||
| for (i = 0; i < kGroup3RandomLength - 1; i++) { | |||
| ss << dis(gen); | |||
| } | |||
| ss << "-"; | |||
| ss << dis2(gen); | |||
| for (i = 0; i < kGroup3RandomLength - 1; i++) { | |||
| for (i = 0; i < kGroup4RandomLength - 1; i++) { | |||
| ss << dis(gen); | |||
| } | |||
| ss << "-"; | |||
| for (i = 0; i < kGroup4RandomLength; i++) { | |||
| for (i = 0; i < kGroup5RandomLength; i++) { | |||
| ss << dis(gen); | |||
| } | |||
| return ss.str(); | |||
| @@ -121,7 +121,14 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) { | |||
| MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!"; | |||
| } | |||
| } | |||
| bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) { | |||
| if (node_role == NodeRole::SERVER && (rank_id > ClusterConfig::server_num() - 1)) { | |||
| return false; | |||
| } else if (node_role == NodeRole::WORKER && (rank_id > ClusterConfig::worker_num() - 1)) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -48,6 +48,7 @@ | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| @@ -66,6 +67,7 @@ class CommUtil { | |||
| static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); | |||
| static std::string GenerateUUID(); | |||
| static std::string NodeRoleToString(const NodeRole &role); | |||
| static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id); | |||
| private: | |||
| static std::random_device rd; | |||
| @@ -47,13 +47,17 @@ void Node::ProcessHeartbeatResp(const CommMessage &message) { | |||
| is_ready_ = heartbeat_resp_message.is_cluster_ready(); | |||
| if (is_ready_.load()) { | |||
| wait_start_cond_.notify_all(); | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!"; | |||
| } | |||
| is_finish_ = heartbeat_resp_message.is_cluster_finish(); | |||
| if (is_finish_.load()) { | |||
| wait_finish_cond_.notify_all(); | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!"; | |||
| } | |||
| is_timeout_ = heartbeat_resp_message.is_cluster_timeout(); | |||
| if (is_timeout_ && on_node_event_message_) { | |||
| is_ready_ = true; | |||
| wait_start_cond_.notify_all(); | |||
| on_node_event_message_(NodeEvent::NODE_TIMEOUT); | |||
| } | |||
| } | |||
| @@ -64,7 +68,9 @@ void Node::FetchServers(const std::shared_ptr<TcpClient> &client) { | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| SendMessageSync(client, message); | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(EXCEPTION) << "Fetch servers address timeout!"; | |||
| } | |||
| } | |||
| void Node::ProcessFetchServersResp(const CommMessage &message) { | |||
| @@ -72,10 +78,10 @@ void Node::ProcessFetchServersResp(const CommMessage &message) { | |||
| fetch_servers_resp_message.ParseFromString(message.data()); | |||
| for (const auto &it : fetch_servers_resp_message.servers_meta()) { | |||
| server_rank_ids_[it.rank_id()] = std::make_pair(it.ip(), it.port()); | |||
| nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port()); | |||
| } | |||
| MS_LOG(DEBUG) << "The all server host size is:" << server_rank_ids_.size(); | |||
| MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size(); | |||
| } | |||
| std::string Node::node_id() const { return node_info_.node_id_; } | |||
| @@ -86,19 +92,128 @@ void Node::set_callback(const OnNodeEventMessage &on_node_event_message) { | |||
| on_node_event_message_ = on_node_event_message; | |||
| } | |||
| void Node::Wait(uint64_t request_id) { | |||
| std::unique_lock<std::mutex> lock(message_mutex_); | |||
| message_tracker_cond_.wait(lock, [&] { | |||
| bool Node::Wait(uint64_t request_id, const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(message_tracker_mutex_); | |||
| bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second; | |||
| if (ret) { | |||
| MS_LOG(DEBUG) << "Message tracker remove request id:" << request_id; | |||
| message_tracker_.erase(request_id); | |||
| } | |||
| return ret; | |||
| }); | |||
| message_tracker_.erase(request_id); | |||
| return res; | |||
| } | |||
| bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| const uint32_t &timeout) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| return SendMessageSync(client, comm_message); | |||
| } | |||
| void Node::Disconnect(const std::shared_ptr<TcpClient> &client) { | |||
| bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||
| const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||
| if (rank_ids.size() != data.size()) { | |||
| MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!"; | |||
| } | |||
| for (size_t it = 0; it < rank_ids.size(); ++it) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(data.at(it)); | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| CommMessage *comm_message_resp, const uint32_t &timeout) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| comm_message_resp = &res[rank_id]; | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| client->SendMessage(comm_message); | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||
| std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||
| if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) { | |||
| MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; | |||
| } | |||
| size_t len = rank_ids.size(); | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| for (size_t it = 0; it < len; ++it) { | |||
| comm_message_resp->at(it) = &res[rank_ids.at(it)]; | |||
| } | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| for (size_t it = 0; it < len; ++it) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(data.at(it)); | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool Node::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::FINISH); | |||
| @@ -108,36 +223,43 @@ void Node::Disconnect(const std::shared_ptr<TcpClient> &client) { | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(finish_message.SerializeAsString()); | |||
| SendMessageSync(client, message); | |||
| WaitForDisconnect(); | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(EXCEPTION) << "Disconnect timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!"; | |||
| return WaitForDisconnect(timeout); | |||
| } | |||
| void Node::WaitForStart() { | |||
| bool Node::WaitForStart(const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(wait_start_mutex_); | |||
| wait_start_cond_.wait(lock, [&] { | |||
| if (is_ready_.load()) { | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success start!"; | |||
| bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| bool res = is_ready_.load(); | |||
| if (res) { | |||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!"; | |||
| } | |||
| return is_ready_.load(); | |||
| return res; | |||
| }); | |||
| return res; | |||
| } | |||
| void Node::WaitForDisconnect() { | |||
| bool Node::WaitForDisconnect(const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(wait_finish_mutex_); | |||
| wait_finish_cond_.wait(lock, [&] { | |||
| bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| if (is_finish_.load()) { | |||
| MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success finish!"; | |||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!"; | |||
| } | |||
| return is_finish_.load(); | |||
| }); | |||
| return res; | |||
| } | |||
| void Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||
| bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| Wait(request_id); | |||
| return Wait(request_id, timeout); | |||
| } | |||
| void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||
| @@ -147,12 +269,83 @@ void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const Comm | |||
| } | |||
| void Node::NotifyMessageArrival(const CommMessage &message) { | |||
| std::lock_guard<std::mutex> lock(message_tracker_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| uint64_t request_id = message_meta.request_id(); | |||
| message_tracker_[request_id].second++; | |||
| message_tracker_cond_.notify_all(); | |||
| } | |||
| const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) { | |||
| std::lock_guard<std::mutex> lock(client_mutex_); | |||
| if (connected_nodes_.find(rank_id) != connected_nodes_.end()) { | |||
| return connected_nodes_[rank_id]; | |||
| } else { | |||
| if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) { | |||
| MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!"; | |||
| } | |||
| std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; | |||
| uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; | |||
| auto client = std::make_shared<TcpClient>(ip, port); | |||
| client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| case NodeCommand::SEND_DATA: | |||
| ProcessSendDataResp(message); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| NotifyMessageArrival(message); | |||
| }); | |||
| client->Init(); | |||
| connected_nodes_[rank_id] = client; | |||
| return connected_nodes_[rank_id]; | |||
| } | |||
| } | |||
| void Node::ProcessSendDataResp(const CommMessage &message) { | |||
| std::lock_guard<std::mutex> lock(receive_messages_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| const uint32_t &rank_id = message_meta.rank_id(); | |||
| const uint64_t request_id = message_meta.request_id(); | |||
| auto it = receive_messages_.find(request_id); | |||
| if (it != receive_messages_.end()) { | |||
| it->second.insert(std::make_pair(rank_id, message)); | |||
| } else { | |||
| std::unordered_map<uint32_t, CommMessage> res; | |||
| res.insert(std::make_pair(rank_id, message)); | |||
| receive_messages_[request_id] = res; | |||
| } | |||
| RunMessageCallback(request_id); | |||
| } | |||
| void Node::RunMessageCallback(const uint64_t &request_id) { | |||
| message_callbacks_mutex_.lock(); | |||
| if (message_tracker_[request_id].first == message_tracker_[request_id].second - 1) { | |||
| auto it = message_callbacks_.find(request_id); | |||
| if (it != message_callbacks_.end()) { | |||
| message_callbacks_mutex_.unlock(); | |||
| if (it->second) { | |||
| it->second(); | |||
| } | |||
| message_callbacks_mutex_.lock(); | |||
| message_callbacks_.erase(it); | |||
| } | |||
| } | |||
| message_callbacks_mutex_.unlock(); | |||
| } | |||
| void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) { | |||
| if (!message_callback) { | |||
| return; | |||
| } | |||
| std::lock_guard<std::mutex> lock(message_callbacks_mutex_); | |||
| message_callbacks_[request_id] = message_callback; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -21,15 +21,15 @@ | |||
| #include <cstdlib> | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <condition_variable> | |||
| #include <utility> | |||
| #include <tuple> | |||
| #include <map> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| @@ -42,6 +42,8 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| constexpr int kTimeoutInSeconds = 30; | |||
| constexpr int kCommTimeoutInSeconds = 3; | |||
| class Node { | |||
| public: | |||
| Node() | |||
| @@ -49,51 +51,83 @@ class Node { | |||
| is_finish_(false), | |||
| is_timeout_(false), | |||
| is_already_stopped_(true), | |||
| is_already_finished_(false), | |||
| next_request_id_(0), | |||
| heart_beat_thread_(nullptr) {} | |||
| virtual ~Node() = default; | |||
| using OnNodeEventMessage = std::function<void(const NodeEvent &event)>; | |||
| void set_callback(const OnNodeEventMessage &on_node_event_message); | |||
| using MessageCallback = std::function<void()>; | |||
| virtual bool Start(const uint32_t &timeout = kTimeoutInSeconds) = 0; | |||
| virtual bool Stop() = 0; | |||
| virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; | |||
| void set_callback(const OnNodeEventMessage &on_node_event_message); | |||
| std::string node_id() const; | |||
| uint32_t rank_id() const; | |||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| void Wait(uint64_t request_id); | |||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| CommMessage *const comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| protected: | |||
| void Heartbeat(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessHeartbeatResp(const CommMessage &message); | |||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessFetchServersResp(const CommMessage &message); | |||
| void Disconnect(const std::shared_ptr<TcpClient> &client); | |||
| void WaitForStart(); | |||
| void WaitForDisconnect(); | |||
| void SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||
| bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); | |||
| bool WaitForStart(const uint32_t &timeout); | |||
| bool WaitForDisconnect(const uint32_t &timeout); | |||
| bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||
| void NotifyMessageArrival(const CommMessage &message); | |||
| const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); | |||
| void ProcessSendDataResp(const CommMessage &message); | |||
| void RunMessageCallback(const uint64_t &request_id); | |||
| void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); | |||
| NodeInfo node_info_; | |||
| std::atomic<bool> is_ready_; | |||
| std::atomic<bool> is_finish_; | |||
| std::atomic<bool> is_timeout_; | |||
| std::atomic<bool> is_already_stopped_; | |||
| std::atomic<bool> is_already_finished_; | |||
| std::atomic_uint64_t next_request_id_; | |||
| std::unique_ptr<std::thread> heart_beat_thread_; | |||
| OnNodeEventMessage on_node_event_message_; | |||
| // rank_id-><ip, port> | |||
| std::unordered_map<int, std::pair<std::string, uint16_t>> server_rank_ids_; | |||
| // <NodeRole,rank_id>-><ip, port> | |||
| std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; | |||
| // rank_id->tcpclient | |||
| std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_; | |||
| // timestamp-><expected responses, actual responses> | |||
| // request_id-><expected responses, actual responses> | |||
| std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; | |||
| std::mutex message_mutex_; | |||
| std::mutex message_tracker_mutex_; | |||
| std::condition_variable message_tracker_cond_; | |||
| std::mutex wait_finish_mutex_; | |||
| std::condition_variable wait_finish_cond_; | |||
| std::mutex wait_start_mutex_; | |||
| std::condition_variable wait_start_cond_; | |||
| std::mutex finish_mutex_; | |||
| std::mutex client_mutex_; | |||
| // request_id -> <rank_id, CommMessage> | |||
| std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_; | |||
| std::mutex receive_messages_mutex_; | |||
| // request_id -> MessageCallback | |||
| std::unordered_map<uint64_t, MessageCallback> message_callbacks_; | |||
| std::mutex message_callbacks_mutex_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -39,6 +39,10 @@ message MessageMeta { | |||
| NodeCommand cmd = 1; | |||
| // the request id of this message | |||
| uint64 request_id = 2; | |||
| // the role of the current node: worker,server,scheduler | |||
| NodeRole role = 3; | |||
| // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] | |||
| int32 rank_id = 4; | |||
| } | |||
| message RegisterMessage { | |||
| @@ -249,7 +249,7 @@ void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb | |||
| void TcpClient::SendMessage(const CommMessage &message) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| std::vector<unsigned char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { | |||
| @@ -23,7 +23,6 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; } | |||
| void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| @@ -32,11 +31,11 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| while (num > 0) { | |||
| if (remaining_length_ == 0) { | |||
| for (int i = 0; i < 4 && num > 0; ++i) { | |||
| for (int i = 0; i < kHeaderLen && num > 0; ++i) { | |||
| header_[++header_index_] = *(buffer_data + i); | |||
| --num; | |||
| if (header_index_ == 3) { | |||
| message_length_ = *reinterpret_cast<const uint32_t *>(header_); | |||
| if (header_index_ == kHeaderLen - 1) { | |||
| message_length_ = *reinterpret_cast<const size_t *>(header_); | |||
| remaining_length_ = message_length_; | |||
| message_buffer_.reset(new unsigned char[remaining_length_]); | |||
| buffer_data += (i + 1); | |||
| @@ -46,7 +45,7 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| } | |||
| if (remaining_length_ > 0 && num > 0) { | |||
| uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num; | |||
| size_t copy_len = remaining_length_ <= num ? remaining_length_ : num; | |||
| remaining_length_ -= copy_len; | |||
| num -= copy_len; | |||
| @@ -71,7 +70,6 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| } | |||
| } | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -31,6 +31,7 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| using messageReceive = std::function<void(const CommMessage &message)>; | |||
| constexpr int kHeaderLen = 8; | |||
| class TcpMessageHandler { | |||
| public: | |||
| @@ -51,10 +52,10 @@ class TcpMessageHandler { | |||
| bool is_parsed_; | |||
| std::unique_ptr<unsigned char> message_buffer_; | |||
| size_t message_length_; | |||
| uint32_t remaining_length_; | |||
| char header_[4]; | |||
| size_t remaining_length_; | |||
| char header_[8]; | |||
| int header_index_; | |||
| uint32_t last_copy_len_; | |||
| size_t last_copy_len_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -55,7 +55,7 @@ const evutil_socket_t &TcpConnection::GetFd() const { return fd_; } | |||
| void TcpConnection::SendMessage(const CommMessage &message) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| std::vector<unsigned char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size, | |||
| @@ -0,0 +1,187 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/worker_node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| WorkerNode::~WorkerNode() { | |||
| MS_LOG(INFO) << "Stop worker node!"; | |||
| if (!is_already_stopped_.load()) { | |||
| is_ready_ = true; | |||
| is_timeout_ = true; | |||
| client_to_scheduler_->Stop(); | |||
| if (!connected_nodes_.empty()) { | |||
| for (auto &connected_node : connected_nodes_) { | |||
| connected_node.second->Stop(); | |||
| } | |||
| } | |||
| client_to_scheduler_->StopEventBase(); | |||
| if (worker_thread_->joinable()) { | |||
| worker_thread_->join(); | |||
| } | |||
| if (heart_beat_thread_->joinable()) { | |||
| heart_beat_thread_->join(); | |||
| } | |||
| is_already_stopped_ = true; | |||
| } | |||
| } | |||
| bool WorkerNode::Start(const uint32_t &timeout) { | |||
| MS_LOG(INFO) << "Starting worker node!"; | |||
| Initialize(); | |||
| Register(); | |||
| Heartbeat(client_to_scheduler_); | |||
| if (!WaitForStart(timeout)) { | |||
| MS_LOG(ERROR) << "Start Worker node timeout!"; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "The node is ready to fetch servers!"; | |||
| if (!is_timeout_.load()) { | |||
| FetchServers(client_to_scheduler_); | |||
| MS_LOG(INFO) << "Fetch servers successful!"; | |||
| } | |||
| MS_LOG(INFO) << "The Worker node has successfully started."; | |||
| return true; | |||
| } | |||
| void WorkerNode::Register() { | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::REGISTER); | |||
| RegisterMessage register_message; | |||
| register_message.set_node_id(node_info_.node_id_); | |||
| register_message.set_role(node_info_.node_role_); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(register_message.SerializeAsString()); | |||
| if (!SendMessageSync(client_to_scheduler_, comm_message)) { | |||
| MS_LOG(EXCEPTION) << "Worker node register timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The worker node id:" << node_info_.node_id_ | |||
| << "is registering to scheduler, the request id is:" << message_meta.request_id(); | |||
| } | |||
| void WorkerNode::ProcessRegisterResp(const CommMessage &message) { | |||
| RegisterRespMessage register_resp_message; | |||
| register_resp_message.ParseFromString(message.data()); | |||
| if (register_resp_message.node_id() != node_info_.node_id_) { | |||
| MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() | |||
| << " is not match the current node id:" << node_info_.node_id_; | |||
| } | |||
| node_info_.rank_id_ = register_resp_message.rank_id(); | |||
| MS_LOG(INFO) << "The client node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; | |||
| } | |||
| void WorkerNode::Initialize() { | |||
| is_already_stopped_ = false; | |||
| node_info_.node_id_ = CommUtil::GenerateUUID(); | |||
| node_info_.node_role_ = NodeRole::WORKER; | |||
| MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_; | |||
| InitClientToScheduler(); | |||
| } | |||
| void WorkerNode::InitClientToScheduler() { | |||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | |||
| uint16_t scheduler_port = ClusterConfig::scheduler_port(); | |||
| client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port); | |||
| client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| case NodeCommand::HEARTBEAT: | |||
| ProcessHeartbeatResp(message); | |||
| break; | |||
| case NodeCommand::REGISTER: | |||
| ProcessRegisterResp(message); | |||
| break; | |||
| case NodeCommand::FETCH_SERVER: | |||
| ProcessFetchServersResp(message); | |||
| break; | |||
| case NodeCommand::FINISH: | |||
| MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!"; | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| NotifyMessageArrival(message); | |||
| }); | |||
| client_to_scheduler_->Init(); | |||
| worker_thread_ = std::make_unique<std::thread>([&]() { | |||
| MS_LOG(INFO) << "The worker node start a tcp client!"; | |||
| client_to_scheduler_->Start(); | |||
| }); | |||
| worker_thread_->detach(); | |||
| } | |||
| bool WorkerNode::Stop() { | |||
| MS_LOG(INFO) << "Stop worker node!"; | |||
| if (!is_already_stopped_.load()) { | |||
| is_ready_ = true; | |||
| is_timeout_ = true; | |||
| client_to_scheduler_->Stop(); | |||
| if (!connected_nodes_.empty()) { | |||
| for (auto &connected_node : connected_nodes_) { | |||
| connected_node.second->Stop(); | |||
| } | |||
| } | |||
| client_to_scheduler_->StopEventBase(); | |||
| if (worker_thread_->joinable()) { | |||
| worker_thread_->join(); | |||
| } | |||
| if (heart_beat_thread_->joinable()) { | |||
| heart_beat_thread_->join(); | |||
| } | |||
| is_already_stopped_ = true; | |||
| } | |||
| return true; | |||
| } | |||
| bool WorkerNode::Finish(const uint32_t &timeout) { | |||
| std::lock_guard<std::mutex> lock(finish_mutex_); | |||
| if (is_already_finished_) { | |||
| MS_LOG(INFO) << "Worker node already finish!"; | |||
| return true; | |||
| } | |||
| MS_LOG(INFO) << "Finish worker node!"; | |||
| is_already_finished_ = true; | |||
| return Disconnect(client_to_scheduler_, timeout); | |||
| } | |||
| bool WorkerNode::BroadcastToServers(const std::string &message) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); | |||
| for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient((*it).first.second); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| return Wait(request_id); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | |||
| #include <atomic> | |||
| #include <cstdlib> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <thread> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <condition_variable> | |||
| #include <algorithm> | |||
| #include <tuple> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "ps/core/tcp_client.h" | |||
| #include "ps/core/tcp_server.h" | |||
| #include "ps/core/node.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class WorkerNode : public Node { | |||
| public: | |||
| WorkerNode() : client_to_scheduler_(nullptr), worker_thread_(nullptr) {} | |||
| ~WorkerNode() override; | |||
| bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| bool Stop() override; | |||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| bool BroadcastToServers(const std::string &message); | |||
| private: | |||
| void Register(); | |||
| void ProcessRegisterResp(const CommMessage &message); | |||
| void Initialize(); | |||
| void InitClientToScheduler(); | |||
| std::shared_ptr<TcpClient> client_to_scheduler_; | |||
| std::unique_ptr<std::thread> worker_thread_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ | |||
| @@ -39,6 +39,14 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { | |||
| EXPECT_TRUE(!interface.empty()); | |||
| EXPECT_TRUE(!ip.empty()); | |||
| } | |||
| TEST_F(TestCommUtil, ValidateRankId) { | |||
| ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999); | |||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); | |||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); | |||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); | |||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2)); | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -33,117 +33,118 @@ class TestTcpMessageHandler : public UT::Common { | |||
| void TearDown() override {} | |||
| }; | |||
| TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) { | |||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||
| std::string data(1000, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| char result[1007]; | |||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| char result[1011]; | |||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| handler.ReceiveMessage(result, buf_size + 4); | |||
| memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| handler.ReceiveMessage(result, buf_size + kHeaderLen); | |||
| } | |||
| TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) { | |||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||
| std::string data(1000, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| char result[2014]; | |||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| char result[2022] = {0}; | |||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); | |||
| ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size); | |||
| ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 2 * buf_size + 4 * 2); | |||
| handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2); | |||
| } | |||
| TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) { | |||
| TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); }); | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); }); | |||
| std::string data(4087, 'a'); | |||
| std::string data(4081, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| char result[4096]; | |||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| char result[4096] = {0}; | |||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2); | |||
| ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 4096); | |||
| ret = memcpy_s(result, 2, &buf_size + 2, 2); | |||
| auto temp = reinterpret_cast<char *>(&buf_size); | |||
| ret = memcpy_s(result, 4, temp + 4, 4); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size); | |||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 4092); | |||
| handler.ReceiveMessage(result, 4088); | |||
| } | |||
| TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) { | |||
| TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); }); | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); }); | |||
| std::string data(4085, 'a'); | |||
| std::string data(4077, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| char result[4096]; | |||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| char result[4096] = {0}; | |||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); | |||
| ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| @@ -155,9 +156,8 @@ TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 4088); | |||
| handler.ReceiveMessage(result, 4080); | |||
| } | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||