| @@ -20,8 +20,9 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) { | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::REGISTER); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| message_meta->set_cmd(NodeCommand::REGISTER); | |||
| RegisterMessage register_message; | |||
| register_message.set_node_id(node_info_.node_id_); | |||
| @@ -29,11 +30,8 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) { | |||
| register_message.set_ip(node_info_.ip_); | |||
| register_message.set_port(node_info_.port_); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| comm_message.set_data(register_message.SerializeAsString()); | |||
| comm_message.set_user_cmd(""); | |||
| if (!SendMessageSync(client, comm_message)) { | |||
| if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(), | |||
| register_message.ByteSizeLong())) { | |||
| MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " register timeout!"; | |||
| } | |||
| @@ -42,9 +40,11 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) { | |||
| << " the node id:" << node_info_.node_id_ << "is registering to scheduler!"; | |||
| } | |||
| void AbstractNode::ProcessRegisterResp(const CommMessage &message) { | |||
| void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| RegisterRespMessage register_resp_message; | |||
| register_resp_message.ParseFromString(message.data()); | |||
| register_resp_message.ParseFromArray(data, size); | |||
| if (register_resp_message.node_id() != node_info_.node_id_) { | |||
| MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() | |||
| << " is not match the current node id:" << node_info_.node_id_; | |||
| @@ -52,28 +52,29 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) { | |||
| node_info_.rank_id_ = register_resp_message.rank_id(); | |||
| MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; | |||
| MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_ | |||
| << " registered scheduler success!"; | |||
| } | |||
| bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) { | |||
| bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command, | |||
| const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| if (node_role != NodeRole::SERVER) { | |||
| MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes"; | |||
| } | |||
| CommMessage &comm_message = const_cast<CommMessage &>(message); | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); | |||
| uint64_t request_id = AddMessageTrack(nodes_address_.size()); | |||
| for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| message_meta.set_rank_id(node_info_.rank_id_); | |||
| message_meta.set_role(node_info_.node_role_); | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| message_meta->set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta->set_request_id(request_id); | |||
| message_meta->set_rank_id(node_info_.rank_id_); | |||
| message_meta->set_role(node_info_.node_role_); | |||
| message_meta->set_user_cmd(command); | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| auto client = GetOrCreateTcpClient((*it).first.second); | |||
| client->SendMessage(comm_message); | |||
| client->SendMessage(message_meta, Protos::RAW, message.get(), size); | |||
| } | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| @@ -84,28 +85,27 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me | |||
| on_node_event_message_ = on_node_event_message; | |||
| } | |||
| bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, | |||
| const uint32_t &timeout) { | |||
| bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, | |||
| int command, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| CommMessage &comm_message = const_cast<CommMessage &>(message); | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_rank_id(node_info_.rank_id_); | |||
| message_meta.set_role(node_info_.node_role_); | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| message_meta->set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta->set_rank_id(node_info_.rank_id_); | |||
| message_meta->set_role(node_info_.node_role_); | |||
| message_meta->set_user_cmd(command); | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| return SendMessageSync(client, comm_message, timeout); | |||
| return SendMessageSync(client, message_meta, Protos::RAW, data.get(), len, timeout); | |||
| } | |||
| bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<CommMessage> &data, const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||
| const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command, | |||
| const uint32_t &timeout) { | |||
| uint64_t request_id = AddMessageTrack(data.size()); | |||
| if (rank_ids.size() != data.size()) { | |||
| MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!"; | |||
| @@ -115,34 +115,32 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| message_meta.set_rank_id(node_info_.rank_id_); | |||
| message_meta.set_role(node_info_.node_role_); | |||
| CommMessage &comm_message = const_cast<CommMessage &>(data.at(it)); | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| message_meta->set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta->set_request_id(request_id); | |||
| message_meta->set_rank_id(node_info_.rank_id_); | |||
| message_meta->set_role(node_info_.node_role_); | |||
| message_meta->set_user_cmd(command); | |||
| auto send = data.at(it); | |||
| auto len = lens.at(it); | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| client->SendMessage(message_meta, Protos::RAW, send.get(), len); | |||
| } | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, | |||
| CommMessage *output, const uint32_t &timeout) { | |||
| bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, | |||
| int command, VectorPtr *output, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| CommMessage &comm_message = const_cast<CommMessage &>(message); | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| uint64_t request_id = AddMessageTrack(1); | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| @@ -151,59 +149,59 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| message_meta.set_rank_id(node_info_.rank_id_); | |||
| message_meta.set_role(node_info_.node_role_); | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| message_meta->set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta->set_request_id(request_id); | |||
| message_meta->set_rank_id(node_info_.rank_id_); | |||
| message_meta->set_role(node_info_.node_role_); | |||
| message_meta->set_user_cmd(command); | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| client->SendMessage(comm_message); | |||
| client->SendMessage(message_meta, Protos::RAW, message.get(), len); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<CommMessage> &data, std::vector<CommMessage> *output, | |||
| const uint32_t &timeout) { | |||
| const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command, | |||
| std::vector<VectorPtr> *output, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | |||
| uint64_t request_id = AddMessageTrack(data.size()); | |||
| if (rank_ids.size() != data.size()) { | |||
| MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; | |||
| } | |||
| size_t len = rank_ids.size(); | |||
| size_t size = rank_ids.size(); | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| for (size_t it = 0; it < len; ++it) { | |||
| for (size_t it = 0; it < size; ++it) { | |||
| (*output).push_back(res[rank_ids.at(it)]); | |||
| } | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| for (size_t it = 0; it < len; ++it) { | |||
| for (size_t it = 0; it < size; ++it) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta.set_request_id(request_id); | |||
| message_meta.set_rank_id(node_info_.rank_id_); | |||
| message_meta.set_role(node_info_.node_role_); | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| message_meta->set_cmd(NodeCommand::SEND_DATA); | |||
| message_meta->set_request_id(request_id); | |||
| message_meta->set_rank_id(node_info_.rank_id_); | |||
| message_meta->set_role(node_info_.node_role_); | |||
| message_meta->set_user_cmd(command); | |||
| CommMessage &comm_message = const_cast<CommMessage &>(data.at(it)); | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| auto send = data.at(it); | |||
| auto len = data_lens.at(it); | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| client->SendMessage(message_meta, Protos::RAW, send.get(), len); | |||
| } | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| @@ -220,55 +218,61 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { | |||
| return res; | |||
| } | |||
| uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | |||
| const CommMessage &message) { | |||
| uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, | |||
| size_t size) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| CommMessage &comm_message = const_cast<CommMessage &>(message); | |||
| MessageMeta message_meta; | |||
| message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA); | |||
| message_meta.set_rank_id(node_info_.rank_id_); | |||
| message_meta.set_role(node_info_.node_role_); | |||
| std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>(); | |||
| message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA); | |||
| message_meta->set_rank_id(node_info_.rank_id_); | |||
| message_meta->set_role(node_info_.node_role_); | |||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| return SendMessageAsync(client, comm_message); | |||
| return SendMessageAsync(client, message_meta, Protos::RAW, data, size); | |||
| } | |||
| std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, | |||
| const uint32_t &rank_id, CommMessage *output) { | |||
| const uint32_t &rank_id, void **output, | |||
| size_t *size) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| MS_EXCEPTION_IF_NULL(size); | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| receive_callbacks_mutex_.lock(); | |||
| uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); | |||
| receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = false; | |||
| if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { | |||
| *output = received_data_[std::make_pair(rank_id, rank_request_id)]; | |||
| auto res = received_data_[std::make_pair(rank_id, rank_request_id)]; | |||
| *output = res->data(); | |||
| *size = res->size(); | |||
| received_data_.erase(std::make_pair(rank_id, rank_request_id)); | |||
| receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true; | |||
| MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id; | |||
| } else { | |||
| set_receive_callback(rank_id, rank_request_id, [=]() { | |||
| receive_callbacks_[std::make_pair(rank_id, rank_request_id)] = [=]() mutable { | |||
| receive_callbacks_mutex_.lock(); | |||
| *output = received_data_[std::make_pair(rank_id, rank_request_id)]; | |||
| auto res = received_data_[std::make_pair(rank_id, rank_request_id)]; | |||
| *output = res->data(); | |||
| *size = res->size(); | |||
| received_data_.erase(std::make_pair(rank_id, rank_request_id)); | |||
| receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true; | |||
| MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id; | |||
| receive_callbacks_mutex_.unlock(); | |||
| }); | |||
| }; | |||
| } | |||
| receive_callbacks_mutex_.unlock(); | |||
| return std::make_pair(rank_id, rank_request_id); | |||
| } | |||
| bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) { | |||
| std::unique_lock<std::mutex> lock(receive_callbacks_mutex_); | |||
| bool res = receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||
| if (actual_rank_request_ids_.count(request_id.first) && | |||
| (actual_rank_request_ids_[request_id.first] >= request_id.second)) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| }); | |||
| bool res = | |||
| receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; }); | |||
| return res; | |||
| } | |||
| @@ -297,17 +301,15 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) | |||
| } | |||
| bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::HEARTBEAT); | |||
| auto meta = std::make_shared<MessageMeta>(); | |||
| meta->set_cmd(NodeCommand::HEARTBEAT); | |||
| HeartbeatMessage heartbeat_message; | |||
| heartbeat_message.set_node_id(node_info_.node_id_); | |||
| heartbeat_message.set_is_node_finish(is_node_finish); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(heartbeat_message.SerializeAsString()); | |||
| if (!SendMessageSync(client, message)) { | |||
| if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(), | |||
| heartbeat_message.ByteSizeLong())) { | |||
| MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | |||
| } | |||
| return true; | |||
| @@ -331,9 +333,11 @@ bool AbstractNode::CheckSchedulerTimeout() const { | |||
| return false; | |||
| } | |||
| void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { | |||
| void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| HeartbeatRespMessage heartbeat_resp_message; | |||
| heartbeat_resp_message.ParseFromString(message.data()); | |||
| heartbeat_resp_message.ParseFromArray(data, size); | |||
| is_ready_ = heartbeat_resp_message.is_cluster_ready(); | |||
| if (is_ready_.load()) { | |||
| @@ -359,19 +363,22 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { | |||
| } | |||
| void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::FETCH_SERVER); | |||
| auto meta = std::make_shared<MessageMeta>(); | |||
| meta->set_cmd(NodeCommand::FETCH_SERVER); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| if (!SendMessageSync(client, message)) { | |||
| FetchServersMessage fetch_servers; | |||
| fetch_servers.set_node_id(node_info_.node_id_); | |||
| if (!SendMessageSync(client, meta, Protos::PROTOBUF, fetch_servers.SerializeAsString().data(), | |||
| fetch_servers.ByteSizeLong())) { | |||
| MS_LOG(EXCEPTION) << "Fetch servers address timeout!"; | |||
| } | |||
| } | |||
| void AbstractNode::ProcessFetchServersResp(const CommMessage &message) { | |||
| void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| FetchServersRespMessage fetch_servers_resp_message; | |||
| fetch_servers_resp_message.ParseFromString(message.data()); | |||
| fetch_servers_resp_message.ParseFromArray(data, size); | |||
| for (const auto &it : fetch_servers_resp_message.servers_meta()) { | |||
| nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port()); | |||
| @@ -381,16 +388,14 @@ void AbstractNode::ProcessFetchServersResp(const CommMessage &message) { | |||
| } | |||
| bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::FINISH); | |||
| auto meta = std::make_shared<MessageMeta>(); | |||
| meta->set_cmd(NodeCommand::FINISH); | |||
| FinishMessage finish_message; | |||
| finish_message.set_node_id(node_info_.node_id_); | |||
| CommMessage message; | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(finish_message.SerializeAsString()); | |||
| if (!SendMessageSync(client, message)) { | |||
| if (!SendMessageSync(client, meta, Protos::PROTOBUF, finish_message.SerializeAsString().data(), | |||
| finish_message.ByteSizeLong())) { | |||
| MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!"; | |||
| } | |||
| @@ -412,16 +417,17 @@ bool AbstractNode::InitClientToScheduler() { | |||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | |||
| uint16_t scheduler_port = ClusterConfig::scheduler_port(); | |||
| client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port); | |||
| client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||
| if (handlers_.count(message.pb_meta().cmd()) == 0) { | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| if (handlers_[message.pb_meta().cmd()] != nullptr) { | |||
| const auto &handler_ptr = handlers_[message.pb_meta().cmd()]; | |||
| (this->*handler_ptr)(message); | |||
| } | |||
| NotifyMessageArrival(message); | |||
| }); | |||
| client_to_scheduler_->SetMessageCallback( | |||
| [&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) { | |||
| if (handlers_.count(meta->cmd()) == 0) { | |||
| MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; | |||
| } | |||
| if (handlers_[meta->cmd()] != nullptr) { | |||
| const auto &handler_ptr = handlers_[meta->cmd()]; | |||
| (this->*handler_ptr)(meta, data, size); | |||
| } | |||
| NotifyMessageArrival(meta); | |||
| }); | |||
| client_to_scheduler_->Init(); | |||
| client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() { | |||
| @@ -447,19 +453,20 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int & | |||
| std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; | |||
| uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; | |||
| auto client = std::make_shared<TcpClient>(ip, port); | |||
| client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, | |||
| size_t size) { | |||
| switch (meta->cmd()) { | |||
| case NodeCommand::SEND_DATA: | |||
| ProcessSendDataResp(message); | |||
| RunMessageCallback(message.pb_meta().request_id()); | |||
| ProcessSendDataResp(meta, protos, data, size); | |||
| RunMessageCallback(meta->request_id()); | |||
| break; | |||
| case NodeCommand::COLLECTIVE_SEND_DATA: | |||
| MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!"; | |||
| MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!"; | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; | |||
| } | |||
| NotifyMessageArrival(message); | |||
| NotifyMessageArrival(meta); | |||
| }); | |||
| client->Init(); | |||
| connected_nodes_[rank_id] = client; | |||
| @@ -469,8 +476,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int & | |||
| bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| uint64_t request_id = AddMessageTrack(1); | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| @@ -478,29 +484,55 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con | |||
| return Wait(request_id, timeout); | |||
| } | |||
| uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta, | |||
| const Protos &protos, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| uint64_t request_id = AddMessageTrack(1); | |||
| meta->set_request_id(request_id); | |||
| client->SendMessage(meta, protos, data, size); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return request_id; | |||
| } | |||
| void AbstractNode::ProcessSendDataResp(const CommMessage &message) { | |||
| bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta, | |||
| const Protos &protos, const void *data, size_t size, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| uint64_t request_id = AddMessageTrack(1); | |||
| meta->set_request_id(request_id); | |||
| client->SendMessage(meta, protos, data, size); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| bool res = Wait(request_id, timeout); | |||
| return res; | |||
| } | |||
| void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, | |||
| size_t size) { | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| std::lock_guard<std::mutex> lock(receive_messages_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| const uint32_t &rank_id = message_meta.rank_id(); | |||
| const uint64_t request_id = message_meta.request_id(); | |||
| const uint32_t &rank_id = meta->rank_id(); | |||
| const uint64_t request_id = meta->request_id(); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| auto it = receive_messages_.find(request_id); | |||
| VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0); | |||
| if (size > 0) { | |||
| int ret = memcpy_s(received_data.get()->data(), size, data, size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| } | |||
| if (it != receive_messages_.end()) { | |||
| it->second[rank_id] = message; | |||
| it->second[rank_id] = received_data; | |||
| } else { | |||
| std::unordered_map<uint32_t, CommMessage> res; | |||
| res.insert(std::make_pair(rank_id, message)); | |||
| std::unordered_map<uint32_t, VectorPtr> res; | |||
| res.insert(std::make_pair(rank_id, received_data)); | |||
| receive_messages_[request_id] = res; | |||
| } | |||
| } | |||
| @@ -509,7 +541,7 @@ void AbstractNode::RunMessageCallback(const uint64_t &request_id) { | |||
| message_callbacks_mutex_.lock(); | |||
| // When receiving a message's response, Then compare with the desired number of responses, | |||
| // If they are equal, then call the callback function | |||
| if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) { | |||
| if (CheckMessageTrack(request_id)) { | |||
| auto it = message_callbacks_.find(request_id); | |||
| if (it != message_callbacks_.end()) { | |||
| message_callbacks_mutex_.unlock(); | |||
| @@ -533,31 +565,31 @@ void AbstractNode::set_message_callback(const uint64_t &request_id, const Messag | |||
| message_callbacks_[request_id] = callback; | |||
| } | |||
| void AbstractNode::NotifyMessageArrival(const CommMessage &message) { | |||
| void AbstractNode::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) { | |||
| std::lock_guard<std::mutex> lock(message_tracker_mutex_); | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| uint64_t request_id = message_meta.request_id(); | |||
| uint64_t request_id = meta->request_id(); | |||
| message_tracker_[request_id].second++; | |||
| message_tracker_cond_.notify_all(); | |||
| } | |||
| void AbstractNode::set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, | |||
| const MessageCallback &callback) { | |||
| if (!callback) { | |||
| return; | |||
| } | |||
| std::lock_guard<std::mutex> lock(receive_callbacks_mutex_); | |||
| receive_callbacks_[std::make_pair(rank_id, request_id)] = callback; | |||
| } | |||
| void AbstractNode::RunReceiveCallback(const CommMessage &message) { | |||
| void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, | |||
| size_t size) { | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| receive_callbacks_mutex_.lock(); | |||
| uint32_t rank_id = message.pb_meta().rank_id(); | |||
| uint32_t rank_id = meta->rank_id(); | |||
| // When receiving a collective message, Then generate rank request id,compare with the desired rank request id, | |||
| // If they are equal, then call the callback function | |||
| uint64_t rank_request_id = NextActualRankRequestId(rank_id); | |||
| received_data_[std::make_pair(rank_id, rank_request_id)] = message; | |||
| std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0); | |||
| int ret = memcpy_s(received_data->data(), size, data, size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| received_data_[std::make_pair(rank_id, rank_request_id)] = received_data; | |||
| MS_LOG(DEBUG) << "Run Receive data callback,the rank id:" << rank_id << ", the rank request id is:" << rank_request_id | |||
| << ", the send request id is:" << meta->request_id(); | |||
| auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id)); | |||
| if (it != receive_callbacks_.end()) { | |||
| receive_callbacks_mutex_.unlock(); | |||
| @@ -603,6 +635,18 @@ void AbstractNode::InitCommandHandler() { | |||
| handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp; | |||
| handlers_[NodeCommand::FINISH] = nullptr; | |||
| } | |||
| uint64_t AbstractNode::AddMessageTrack(const uint32_t &expected_response) { | |||
| std::lock_guard<std::mutex> lock(message_tracker_mutex_); | |||
| uint64_t request_id = ++next_request_id_; | |||
| message_tracker_[request_id] = std::make_pair(expected_response, 0); | |||
| return request_id; | |||
| } | |||
| bool AbstractNode::CheckMessageTrack(const uint64_t &request_id) { | |||
| std::lock_guard<std::mutex> lock(message_tracker_mutex_); | |||
| return message_tracker_[request_id].first == message_tracker_[request_id].second + 1; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -25,6 +25,7 @@ | |||
| #include <unordered_map> | |||
| #include "ps/core/node.h" | |||
| #include "ps/core/message.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -34,53 +35,63 @@ class AbstractNode : public Node { | |||
| AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} | |||
| ~AbstractNode() override = default; | |||
| typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message); | |||
| typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| bool Broadcast(const enum NodeRole &node_role, const CommMessage &message, | |||
| using DataPtr = std::shared_ptr<unsigned char>; | |||
| using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; | |||
| bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| void set_event_callback(const OnNodeEventMessage &on_node_event_message); | |||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data, | |||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output, | |||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data, | |||
| const std::vector<size_t> &lens, int command, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command, | |||
| VectorPtr *output, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data, | |||
| const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data, | |||
| std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message); | |||
| uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size); | |||
| std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | |||
| CommMessage *output); | |||
| void **output, size_t *size); | |||
| bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| protected: | |||
| void Register(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessRegisterResp(const CommMessage &message); | |||
| void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client); | |||
| bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false); | |||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| void ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| void ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client); | |||
| void UpdateSchedulerTime(); | |||
| bool CheckSchedulerTimeout() const; | |||
| void ProcessHeartbeatResp(const CommMessage &message); | |||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessFetchServersResp(const CommMessage &message); | |||
| bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); | |||
| bool WaitForDisconnect(const uint32_t &timeout); | |||
| bool InitClientToScheduler(); | |||
| const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); | |||
| bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||
| void ProcessSendDataResp(const CommMessage &message); | |||
| bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &, | |||
| const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta, | |||
| const Protos &protos, const void *data, size_t size); | |||
| void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size); | |||
| void RunMessageCallback(const uint64_t &request_id); | |||
| void set_message_callback(const uint64_t &request_id, const MessageCallback &callback); | |||
| void NotifyMessageArrival(const CommMessage &message); | |||
| void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback); | |||
| void RunReceiveCallback(const CommMessage &message); | |||
| void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta); | |||
| void RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size); | |||
| uint64_t NextExpectedRankRequestId(const uint32_t &rank_id); | |||
| uint64_t NextActualRankRequestId(const uint32_t &rank_id); | |||
| void InitCommandHandler(); | |||
| uint64_t AddMessageTrack(const uint32_t &expected_response); | |||
| bool CheckMessageTrack(const uint64_t &request_id); | |||
| std::unique_ptr<std::thread> heart_beat_thread_; | |||
| std::unique_ptr<std::thread> client_to_scheduler_thread_; | |||
| @@ -98,15 +109,16 @@ class AbstractNode : public Node { | |||
| std::mutex message_tracker_mutex_; | |||
| std::condition_variable message_tracker_cond_; | |||
| // the key is: request_id, the value is:<rank_id, CommMessage> | |||
| std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_; | |||
| // the key is: request_id, the value is: <rank_id, RecvMessage> | |||
| std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_; | |||
| std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_; | |||
| std::mutex receive_messages_mutex_; | |||
| // the key is: request_id | |||
| std::unordered_map<uint64_t, MessageCallback> message_callbacks_; | |||
| std::mutex message_callbacks_mutex_; | |||
| // the key is <rank_id, rank_request_id> | |||
| std::map<std::pair<uint32_t, uint64_t>, CommMessage> received_data_; | |||
| std::map<std::pair<uint32_t, uint64_t>, std::shared_ptr<std::vector<unsigned char>>> received_data_; | |||
| std::mutex receive_callbacks_mutex_; | |||
| // the key is <rank_id, rank_request_id> | |||
| std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_; | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_ | |||
#include <cstdint>
#include <memory>
#include <string>
namespace mindspore {
namespace ps {
namespace core {
// Encoding of the payload that follows the serialized meta in a framed message.
// The value is carried in MessageHeader::message_proto_.
enum class Protos : uint32_t { RAW = 0, PROTOBUF = 1, FLATBUFFERS = 2 };

// Command carried by a message.
// NOTE(review): the names match the NodeCommand values used by the node
// handlers; the numeric values presumably mirror the protobuf enum - confirm.
enum class Command {
  TERMINATE = 0,
  REGISTER = 1,
  HEARTBEAT = 2,
  SEND_DATA = 3,
  FETCH_SERVER = 4,
  FINISH = 5,
  COLLECTIVE_SEND_DATA = 6
};

// Role of a node in the cluster.
enum class Role { SERVER = 0, WORKER = 1, SCHEDULER = 2 };

// Fixed-size header placed in front of every framed message.
struct MessageHeader {
  // encoding of the payload
  Protos message_proto_ = Protos::RAW;
  // size in bytes of the serialized meta that directly follows the header
  uint32_t message_meta_length_ = 0;
  // size in bytes of the serialized meta plus the payload
  uint64_t message_length_ = 0;
};

// The header is written to and parsed from the wire as a raw 16-byte block, so
// its size must not change silently.
static_assert(sizeof(MessageHeader) == 16, "MessageHeader must stay exactly 16 bytes");

struct CommandMeta {
  // the command of this message,for example: register,heartbeat,data
  Command cmd = Command::TERMINATE;
  // the request id of this message
  uint64_t request_id = 0;
  // the role of the current node: worker,server,scheduler
  Role role = Role::SERVER;
  // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
  // Defaults to 0; the previous default of 4 was the protobuf field number, not a meaningful rank.
  int32_t rank_id = 0;
};
}  // namespace core
}  // namespace ps
}  // namespace mindspore
| #endif // MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_ | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| syntax = "proto3"; | |||
| import "google/protobuf/any.proto"; | |||
| package mindspore.ps.core; | |||
| option optimize_for = LITE_RUNTIME; | |||
| @@ -44,6 +43,8 @@ message MessageMeta { | |||
| NodeRole role = 3; | |||
| // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] | |||
| int32 rank_id = 4; | |||
| // User-defined commands | |||
| int32 user_cmd = 5; | |||
| } | |||
| message RegisterMessage { | |||
| @@ -76,6 +77,10 @@ message HeartbeatRespMessage { | |||
| bool is_node_timeout = 4; | |||
| } | |||
// Payload of a FETCH_SERVER request.
// NOTE(review): presumably sent by a worker/server node to the scheduler - confirm against the sender.
message FetchServersMessage {
  // id of the node issuing the request
  string node_id = 1;
}
// Payload of the scheduler's reply to a FETCH_SERVER request.
message FetchServersRespMessage {
  // one entry per server known to the scheduler's node manager
  repeated ServersMeta servers_meta = 1;
}
| @@ -95,6 +100,4 @@ message FinishMessage { | |||
// Envelope that bundles the message meta with an opaque payload.
message CommMessage {
  // Field 3 used to be `bytes user_cmd`; the user command now lives in
  // MessageMeta.user_cmd. Reserve the number and name so they are never reused
  // with a different meaning.
  reserved 3;
  reserved "user_cmd";

  MessageMeta pb_meta = 1;
  bytes data = 2;
}
| @@ -38,9 +38,13 @@ bool SchedulerNode::Start(const uint32_t &timeout) { | |||
| } | |||
| void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message) { | |||
| std::shared_ptr<MessageMeta> meta, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| HeartbeatMessage heartbeat_message; | |||
| heartbeat_message.ParseFromString(message->data()); | |||
| heartbeat_message.ParseFromArray(data, size); | |||
| node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); | |||
| @@ -60,10 +64,8 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha | |||
| heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); | |||
| heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); | |||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||
| comm_message->set_data(heartbeat_resp_message.SerializeAsString()); | |||
| server->SendMessage(conn, comm_message); | |||
| server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(), | |||
| heartbeat_resp_message.ByteSizeLong()); | |||
| } | |||
| void SchedulerNode::Initialize() { | |||
| @@ -89,12 +91,13 @@ void SchedulerNode::CreateTcpServer() { | |||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | |||
| uint32_t scheduler_port = ClusterConfig::scheduler_port(); | |||
| server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port); | |||
| server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| if (handlers_.count(message->pb_meta().cmd()) == 0) { | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; | |||
| server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, | |||
| const Protos &protos, const void *data, size_t size) { | |||
| if (handlers_.count(meta->cmd()) == 0) { | |||
| MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; | |||
| } | |||
| const auto &handler_ptr = handlers_[message->pb_meta().cmd()]; | |||
| (this->*handler_ptr)(server_, conn, message); | |||
| const auto &handler_ptr = handlers_[meta->cmd()]; | |||
| (this->*handler_ptr)(server_, conn, meta, data, size); | |||
| }); | |||
| server_->Init(); | |||
| @@ -106,10 +109,14 @@ void SchedulerNode::CreateTcpServer() { | |||
| } | |||
| void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message) { | |||
| std::shared_ptr<MessageMeta> meta, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| MS_LOG(INFO) << "The scheduler process a register message!"; | |||
| RegisterMessage register_message; | |||
| register_message.ParseFromString(message->data()); | |||
| register_message.ParseFromArray(data, size); | |||
| // assign worker node and server node rank id | |||
| int rank_id = node_manager_.NextRankId(register_message); | |||
| @@ -123,32 +130,32 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar | |||
| register_resp_message.set_node_id(node_id); | |||
| register_resp_message.set_rank_id(rank_id); | |||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||
| comm_message->set_data(register_resp_message.SerializeAsString()); | |||
| server->SendMessage(conn, comm_message); | |||
| server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(), | |||
| register_resp_message.ByteSizeLong()); | |||
| } | |||
| void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message) { | |||
| std::shared_ptr<MessageMeta> meta, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| FinishMessage finish_message; | |||
| finish_message.ParseFromString(message->data()); | |||
| finish_message.ParseFromArray(data, size); | |||
| node_manager_.AddFinishNode(finish_message); | |||
| MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); | |||
| server->SendMessage(conn, message); | |||
| server->SendMessage(conn, meta, Protos::PROTOBUF, data, size); | |||
| } | |||
| void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message) { | |||
| std::shared_ptr<MessageMeta> meta, const void *data, size_t size) { | |||
| FetchServersRespMessage fetch_servers_message; | |||
| std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta(); | |||
| *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; | |||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||
| comm_message->set_data(fetch_servers_message.SerializeAsString()); | |||
| server->SendMessage(conn, comm_message); | |||
| server->SendMessage(conn, meta, Protos::PROTOBUF, fetch_servers_message.SerializeAsString().data(), | |||
| fetch_servers_message.ByteSizeLong()); | |||
| } | |||
| void SchedulerNode::StartUpdateClusterStateTimer() { | |||
| @@ -36,13 +36,14 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class SchedulerNode : public Node { | |||
| public: | |||
| SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} | |||
| ~SchedulerNode() override; | |||
| typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message); | |||
| std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; | |||
| bool Stop() override; | |||
| @@ -53,14 +54,14 @@ class SchedulerNode : public Node { | |||
| void InitCommandHandler(); | |||
| void CreateTcpServer(); | |||
| void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message); | |||
| std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message); | |||
| std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| void StartUpdateClusterStateTimer(); | |||
| void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message); | |||
| std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message); | |||
| std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| std::shared_ptr<TcpServer> server_; | |||
| std::unique_ptr<std::thread> scheduler_thread_; | |||
| @@ -46,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) { | |||
| void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } | |||
| void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data, | |||
| size_t size) { | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| message->mutable_pb_meta()->set_role(node_info_.node_role_); | |||
| message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_); | |||
| const MessageMeta &message_meta = message->pb_meta(); | |||
| const uint64_t request_id = message_meta.request_id(); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| meta->set_role(node_info_.node_role_); | |||
| meta->set_rank_id(node_info_.rank_id_); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| server_->SendMessage(conn, message); | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id(); | |||
| server_->SendMessage(conn, meta, Protos::RAW, data.get(), size); | |||
| } | |||
| void ServerNode::CreateTcpServer() { | |||
| @@ -63,17 +63,18 @@ void ServerNode::CreateTcpServer() { | |||
| std::string server_ip; | |||
| CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); | |||
| server_ = std::make_shared<TcpServer>(server_ip, 0); | |||
| server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| switch (message->pb_meta().cmd()) { | |||
| server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, | |||
| const Protos &protos, const void *data, size_t size) { | |||
| switch (meta->cmd()) { | |||
| case NodeCommand::SEND_DATA: | |||
| ProcessSendData(conn, message); | |||
| ProcessSendData(conn, meta, protos, data, size); | |||
| break; | |||
| case NodeCommand::COLLECTIVE_SEND_DATA: | |||
| ProcessCollectiveSendData(conn, message); | |||
| RunReceiveCallback(*message); | |||
| ProcessCollectiveSendData(conn, meta, data, size); | |||
| RunReceiveCallback(meta, protos, data, size); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; | |||
| MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; | |||
| } | |||
| }); | |||
| server_->Init(); | |||
| @@ -99,18 +100,24 @@ void ServerNode::Initialize() { | |||
| MS_LOG(INFO) << "Server node init client successful!"; | |||
| } | |||
| void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, | |||
| const Protos &protos, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| request_handler_(conn, message); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[size]); | |||
| int ret = memcpy_s(res.get(), size, data, size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| request_handler_(conn, meta, res, size); | |||
| } | |||
| void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, | |||
| const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||
| server_->SendMessage(conn, comm_message); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| server_->SendMessage(conn, meta, Protos::RAW, data, size); | |||
| } | |||
| bool ServerNode::Stop() { | |||
| @@ -23,6 +23,7 @@ | |||
| #include <string> | |||
| #include <thread> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "ps/core/cluster_config.h" | |||
| #include "ps/core/tcp_client.h" | |||
| @@ -41,16 +42,19 @@ class ServerNode : public AbstractNode { | |||
| bool Stop() override; | |||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>; | |||
| using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, | |||
| DataPtr data, size_t size)>; | |||
| void set_handler(const RequestHandler &handler); | |||
| void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||
| void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data, size_t size); | |||
| private: | |||
| void CreateTcpServer(); | |||
| void Initialize(); | |||
| void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||
| void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||
| void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos, | |||
| const void *data, size_t size); | |||
| void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, | |||
| const void *data, size_t size); | |||
| std::shared_ptr<TcpServer> server_; | |||
| std::unique_ptr<std::thread> server_thread_; | |||
| @@ -46,11 +46,12 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port) | |||
| server_port_(port), | |||
| is_stop_(true), | |||
| is_connected_(false) { | |||
| message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) { | |||
| if (message_callback_) { | |||
| message_callback_(*this, *message); | |||
| } | |||
| }); | |||
| message_handler_.SetCallback( | |||
| [this](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) { | |||
| if (message_callback_) { | |||
| message_callback_(meta, protos, data, size); | |||
| } | |||
| }); | |||
| } | |||
| TcpClient::~TcpClient() { | |||
| @@ -189,7 +190,7 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { | |||
| void TcpClient::OnReadHandler(const void *buf, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(buf); | |||
| if (read_callback_) { | |||
| read_callback_(*this, buf, num); | |||
| read_callback_(buf, num); | |||
| } | |||
| message_handler_.ReceiveMessage(buf, num); | |||
| } | |||
| @@ -198,7 +199,7 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(arg); | |||
| if (tcp_client->on_timer_callback_) { | |||
| tcp_client->on_timer_callback_(*tcp_client); | |||
| tcp_client->on_timer_callback_(); | |||
| } | |||
| } | |||
| @@ -245,7 +246,7 @@ void TcpClient::Start() { | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | |||
| << "Event base dispatch failed with no events pending or active!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!"; | |||
| } | |||
| void TcpClient::StartWithNoBlock() { | |||
| @@ -256,7 +257,7 @@ void TcpClient::StartWithNoBlock() { | |||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!"; | |||
| } | |||
| void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } | |||
| @@ -265,14 +266,49 @@ bool TcpClient::SendMessage(const CommMessage &message) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| bufferevent_lock(buffer_event_); | |||
| bool res = true; | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| std::vector<unsigned char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); | |||
| if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { | |||
| size_t buf_size = IntToUint(message.ByteSizeLong()); | |||
| uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong()); | |||
| Messageheader header; | |||
| header.message_proto_ = Protos::PROTOBUF; | |||
| header.message_length_ = buf_size; | |||
| header.message_meta_length_ = meta_size; | |||
| if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add header failed!"; | |||
| res = false; | |||
| } | |||
| if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||
| res = false; | |||
| } | |||
| if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||
| res = false; | |||
| } | |||
| bufferevent_unlock(buffer_event_); | |||
| return res; | |||
| } | |||
| bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| bufferevent_lock(buffer_event_); | |||
| bool res = true; | |||
| Messageheader header; | |||
| header.message_proto_ = protos; | |||
| header.message_meta_length_ = SizeToUint(meta->ByteSizeLong()); | |||
| header.message_length_ = size + header.message_meta_length_; | |||
| if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add header failed!"; | |||
| res = false; | |||
| } | |||
| if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { | |||
| if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||
| res = false; | |||
| } | |||
| if (bufferevent_write(buffer_event_, data, size) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||
| res = false; | |||
| } | |||
| @@ -42,10 +42,10 @@ class TcpClient { | |||
| public: | |||
| using OnConnected = std::function<void()>; | |||
| using OnDisconnected = std::function<void()>; | |||
| using OnRead = std::function<void(const TcpClient &, const void *, size_t)>; | |||
| using OnTimeout = std::function<void(const TcpClient &)>; | |||
| using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>; | |||
| using OnTimer = std::function<void(const TcpClient &)>; | |||
| using OnRead = std::function<void(const void *, size_t)>; | |||
| using OnTimeout = std::function<void()>; | |||
| using OnMessage = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>; | |||
| using OnTimer = std::function<void()>; | |||
| explicit TcpClient(const std::string &address, std::uint16_t port); | |||
| virtual ~TcpClient(); | |||
| @@ -61,6 +61,7 @@ class TcpClient { | |||
| void StartWithNoBlock(); | |||
| void SetMessageCallback(const OnMessage &cb); | |||
| bool SendMessage(const CommMessage &message) const; | |||
| bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size); | |||
| void StartTimer(const uint32_t &time); | |||
| void set_timer_callback(const OnTimer &timer); | |||
| const event_base &eventbase(); | |||
| @@ -35,8 +35,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| header_[++header_index_] = *(buffer_data + i); | |||
| --num; | |||
| if (header_index_ == kHeaderLen - 1) { | |||
| message_length_ = *reinterpret_cast<const size_t *>(header_); | |||
| remaining_length_ = message_length_; | |||
| message_header_.message_proto_ = *reinterpret_cast<const Protos *>(header_); | |||
| message_header_.message_meta_length_ = | |||
| *reinterpret_cast<const uint32_t *>(header_ + sizeof(message_header_.message_proto_)); | |||
| message_header_.message_length_ = *reinterpret_cast<const size_t *>( | |||
| header_ + sizeof(message_header_.message_proto_) + sizeof(message_header_.message_meta_length_)); | |||
| remaining_length_ = message_header_.message_length_; | |||
| message_buffer_.reset(new unsigned char[remaining_length_]); | |||
| buffer_data += (i + 1); | |||
| break; | |||
| @@ -57,10 +61,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| } | |||
| if (remaining_length_ == 0) { | |||
| std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>(); | |||
| pb_message->ParseFromArray(message_buffer_.get(), message_length_); | |||
| if (message_callback_) { | |||
| message_callback_(pb_message); | |||
| std::shared_ptr<MessageMeta> pb_message = std::make_shared<MessageMeta>(); | |||
| pb_message->ParseFromArray(message_buffer_.get(), message_header_.message_meta_length_); | |||
| message_callback_(pb_message, message_header_.message_proto_, | |||
| message_buffer_.get() + message_header_.message_meta_length_, | |||
| message_header_.message_length_ - message_header_.message_meta_length_); | |||
| } | |||
| message_buffer_.reset(); | |||
| message_buffer_ = nullptr; | |||
| @@ -24,24 +24,20 @@ | |||
| #include <vector> | |||
| #include "utils/log_adapter.h" | |||
| #include "ps/core/message.h" | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>; | |||
| constexpr int kHeaderLen = 8; | |||
| using messageReceive = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>; | |||
| constexpr int kHeaderLen = 16; | |||
| // Incremental parser for framed TCP messages: a kHeaderLen-byte header, then meta bytes, then payload bytes. | |||
| class TcpMessageHandler { | |||
| public: | |||
| TcpMessageHandler() | |||
| : is_parsed_(false), | |||
| message_buffer_(nullptr), | |||
| message_length_(0), | |||
| remaining_length_(0), | |||
| header_index_(-1), | |||
| last_copy_len_(0) {} | |||
| : is_parsed_(false), message_buffer_(nullptr), remaining_length_(0), header_index_(-1), last_copy_len_(0) {} | |||
| virtual ~TcpMessageHandler() = default; | |||
| void SetCallback(const messageReceive &cb); | |||
| @@ -51,11 +47,12 @@ class TcpMessageHandler { | |||
| messageReceive message_callback_; | |||
| bool is_parsed_; | |||
| // Array form of unique_ptr: ReceiveMessage allocates this buffer with new[], so it must be released with delete[]. | |||
| std::unique_ptr<unsigned char[]> message_buffer_; | |||
| size_t message_length_; | |||
| size_t remaining_length_; | |||
| char header_[8]; | |||
| char header_[16]; | |||
| int header_index_; | |||
| size_t last_copy_len_; | |||
| MessageHeader message_header_; | |||
| std::string mBuffer; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -54,13 +54,39 @@ bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const { | |||
| bufferevent_lock(buffer_event_); | |||
| bool res = true; | |||
| size_t buf_size = message->ByteSizeLong(); | |||
| std::vector<unsigned char> serialized(buf_size); | |||
| message->SerializeToArray(serialized.data(), SizeToInt(buf_size)); | |||
| if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add header failed!"; | |||
| res = false; | |||
| } | |||
| if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { | |||
| if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||
| res = false; | |||
| } | |||
| bufferevent_unlock(buffer_event_); | |||
| return res; | |||
| } | |||
| // Writes one framed message to the connection's bufferevent: fixed header, serialized meta, then the raw payload. | |||
| bool TcpConnection::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, | |||
| size_t size) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| bufferevent_lock(buffer_event_); | |||
| bool res = true; | |||
| MessageHeader header; | |||
| header.message_proto_ = protos; | |||
| header.message_meta_length_ = SizeToUint(meta->ByteSizeLong()); | |||
| // message_length_ covers the meta bytes plus the payload bytes; the header itself is not counted. | |||
| header.message_length_ = size + header.message_meta_length_; | |||
| if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add header failed!"; | |||
| res = false; | |||
| } | |||
| if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||
| res = false; | |||
| } | |||
| if (bufferevent_write(buffer_event_, data, size) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||
| res = false; | |||
| } | |||
| @@ -158,7 +184,7 @@ void TcpServer::Start() { | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | |||
| << "Event base dispatch failed with no events pending or active!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!"; | |||
| } | |||
| void TcpServer::StartWithNoBlock() { | |||
| @@ -169,7 +195,7 @@ void TcpServer::StartWithNoBlock() { | |||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!"; | |||
| } | |||
| void TcpServer::StartTimerOnlyOnce(const uint32_t &time) { | |||
| @@ -260,10 +286,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| server->AddConnection(fd, conn); | |||
| conn->InitConnection([=](std::shared_ptr<CommMessage> message) { | |||
| conn->InitConnection([=](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) { | |||
| OnServerReceiveMessage on_server_receive = server->GetServerReceive(); | |||
| if (on_server_receive) { | |||
| on_server_receive(conn, message); | |||
| on_server_receive(conn, meta, protos, data, size); | |||
| } | |||
| }); | |||
| bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, | |||
| @@ -274,6 +300,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||
| } | |||
| std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| std::shared_ptr<TcpConnection> conn = nullptr; | |||
| if (client_accept_) { | |||
| conn = (client_accept_(*this)); | |||
| @@ -367,9 +394,17 @@ bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr | |||
| return conn->SendMessage(message); | |||
| } | |||
| // Forwards a meta + raw-payload message to the given connection after rejecting null arguments. | |||
| bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, | |||
| const Protos &protos, const void *data, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| // The connection performs the actual framing and write; its result is passed through unchanged. | |||
| const bool is_sent = conn->SendMessage(meta, protos, data, size); | |||
| return is_sent; | |||
| } | |||
| void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) { | |||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||
| for (auto it = connections_.begin(); it != connections_.end(); ++it) { | |||
| SendMessage(it->second, message); | |||
| @@ -36,7 +36,6 @@ | |||
| #include "ps/core/tcp_message_handler.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/convert_utils_base.h" | |||
| namespace mindspore { | |||
| @@ -55,6 +54,7 @@ class TcpConnection { | |||
| virtual void InitConnection(const messageReceive &callback); | |||
| virtual void SendMessage(const void *buffer, size_t num) const; | |||
| bool SendMessage(std::shared_ptr<CommMessage> message) const; | |||
| bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) const; | |||
| virtual void OnReadHandler(const void *buffer, size_t numBytes); | |||
| TcpServer *GetServer() const; | |||
| const evutil_socket_t &GetFd() const; | |||
| @@ -69,7 +69,8 @@ class TcpConnection { | |||
| }; | |||
| using OnServerReceiveMessage = | |||
| std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>; | |||
| std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos, | |||
| const void *data, size_t size)>; | |||
| class TcpServer { | |||
| public: | |||
| @@ -100,6 +101,8 @@ class TcpServer { | |||
| OnServerReceiveMessage GetServerReceive() const; | |||
| void SetMessageCallback(const OnServerReceiveMessage &cb); | |||
| bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||
| // Sends `meta` plus a raw payload of `size` bytes over `conn`. | |||
| bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos, | |||
| const void *data, size_t size); | |||
| void SendMessage(std::shared_ptr<CommMessage> message); | |||
| uint16_t BoundPort() const; | |||
| std::string BoundIp() const; | |||
| @@ -30,7 +30,12 @@ class TestTcpClient : public UT::Common { | |||
| TEST_F(TestTcpClient, InitClientIPError) { | |||
| auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000); | |||
| client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); | |||
| client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) { | |||
| CommMessage message; | |||
| message.ParseFromArray(data, size); | |||
| client->SendMessage(message); | |||
| }); | |||
| ASSERT_THROW(client->Init(), std::exception); | |||
| } | |||
| @@ -38,10 +43,15 @@ TEST_F(TestTcpClient, InitClientIPError) { | |||
| TEST_F(TestTcpClient, InitClientPortErrorNoException) { | |||
| auto client = std::make_unique<TcpClient>("127.0.0.1", -1); | |||
| client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); | |||
| client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) { | |||
| CommMessage message; | |||
| message.ParseFromArray(data, size); | |||
| client->SendMessage(message); | |||
| }); | |||
| EXPECT_NO_THROW(client->Init()); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -33,130 +33,144 @@ class TestTcpMessageHandler : public UT::Common { | |||
| void TearDown() override {} | |||
| }; | |||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { | |||
| TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); }); | |||
| handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) { | |||
| EXPECT_EQ(meta->ByteSizeLong(), 2); | |||
| EXPECT_EQ(size, 1000); | |||
| }); | |||
| std::string data(1000, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| char result[1011]; | |||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||
| char result[1018]; | |||
| MessageMeta meta; | |||
| meta.set_request_id(1); | |||
| EXPECT_EQ(meta.ByteSizeLong(), 2); | |||
| MessageHeader header; | |||
| header.message_proto_ = Protos::RAW; | |||
| header.message_meta_length_ = meta.ByteSizeLong(); | |||
| header.message_length_ = data.length() + meta.ByteSizeLong(); | |||
| int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| handler.ReceiveMessage(result, buf_size + kHeaderLen); | |||
| memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); | |||
| memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length()); | |||
| handler.ReceiveMessage(result, 1018); | |||
| } | |||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { | |||
| TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data_16Header_2meta_1000Data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); }); | |||
| handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) { | |||
| EXPECT_EQ(meta->ByteSizeLong(), 2); | |||
| EXPECT_EQ(size, 1000); | |||
| }); | |||
| std::string data(1000, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| char result[2022] = {0}; | |||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| char result[2036]; | |||
| MessageMeta meta; | |||
| meta.set_request_id(1); | |||
| EXPECT_EQ(meta.ByteSizeLong(), 2); | |||
| MessageHeader header; | |||
| header.message_proto_ = Protos::RAW; | |||
| header.message_meta_length_ = meta.ByteSizeLong(); | |||
| header.message_length_ = data.length() + meta.ByteSizeLong(); | |||
| int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2); | |||
| memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); | |||
| memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length()); | |||
| memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen); | |||
| memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() + data.length(), meta.ByteSizeLong(), | |||
| meta.SerializeAsString().data(), meta.ByteSizeLong()); | |||
| memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() * 2 + data.length(), data.length(), data.data(), | |||
| data.length()); | |||
| handler.ReceiveMessage(result, 2036); | |||
| } | |||
| TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { | |||
| TEST_F(TestTcpMessageHandler, 16header_2meta_4070data_8header_8header_2meta_4070data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); }); | |||
| handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) { | |||
| EXPECT_EQ(meta->ByteSizeLong(), 2); | |||
| EXPECT_EQ(size, 4070); | |||
| }); | |||
| std::string data(4070, 'a'); | |||
| std::string data(4081, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| char result[4096] = {0}; | |||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4); | |||
| MessageMeta meta; | |||
| meta.set_request_id(1); | |||
| EXPECT_EQ(meta.ByteSizeLong(), 2); | |||
| MessageHeader header; | |||
| header.message_proto_ = Protos::RAW; | |||
| header.message_meta_length_ = meta.ByteSizeLong(); | |||
| header.message_length_ = data.length() + meta.ByteSizeLong(); | |||
| int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); | |||
| memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length()); | |||
| memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), 8, &header, 8); | |||
| handler.ReceiveMessage(result, 4096); | |||
| auto temp = reinterpret_cast<char *>(&buf_size); | |||
| ret = memcpy_s(result, 4, temp + 4, 4); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| auto temp = reinterpret_cast<char *>(&header); | |||
| memcpy_s(result, 8, temp + 8, 8); | |||
| memcpy_s(result + 8, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); | |||
| memcpy_s(result + 8 + 2, data.length(), data.data(), data.length()); | |||
| handler.ReceiveMessage(result, 4088); | |||
| handler.ReceiveMessage(result, 4080); | |||
| } | |||
| TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { | |||
| TEST_F(TestTcpMessageHandler, 16Header_2meta_4062Data_16Header_2meta_4062_data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); }); | |||
| handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) { | |||
| EXPECT_EQ(meta->ByteSizeLong(), 2); | |||
| EXPECT_EQ(size, 4062); | |||
| }); | |||
| std::string data(4062, 'a'); | |||
| std::string data(4077, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| char result[4096] = {0}; | |||
| int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen); | |||
| MessageMeta meta; | |||
| meta.set_request_id(1); | |||
| EXPECT_EQ(meta.ByteSizeLong(), 2); | |||
| MessageHeader header; | |||
| header.message_proto_ = Protos::RAW; | |||
| header.message_meta_length_ = meta.ByteSizeLong(); | |||
| header.message_length_ = data.length() + meta.ByteSizeLong(); | |||
| int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); | |||
| memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length()); | |||
| memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen); | |||
| handler.ReceiveMessage(result, 4096); | |||
| ret = memcpy_s(result, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| memcpy_s(result, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); | |||
| memcpy_s(result + meta.ByteSizeLong(), data.length(), data.data(), data.length()); | |||
| handler.ReceiveMessage(result, 4080); | |||
| handler.ReceiveMessage(result, 4064); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -33,11 +33,12 @@ class TestTcpServer : public UT::Common { | |||
| server_ = std::make_unique<TcpServer>("127.0.0.1", 0); | |||
| std::unique_ptr<std::thread> http_server_thread_(nullptr); | |||
| http_server_thread_ = std::make_unique<std::thread>([=]() { | |||
| server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, | |||
| const Protos &protos, const void *data, size_t size) { | |||
| KVMessage kv_message; | |||
| kv_message.ParseFromString(message->data()); | |||
| kv_message.ParseFromArray(data, size); | |||
| EXPECT_EQ(2, kv_message.keys_size()); | |||
| server_->SendMessage(conn, message); | |||
| server_->SendMessage(conn, meta, protos, data, size); | |||
| }); | |||
| server_->Init(); | |||
| server_->Start(); | |||
| @@ -61,23 +62,24 @@ TEST_F(TestTcpServer, ServerSendMessage) { | |||
| std::cout << server_->BoundPort() << std::endl; | |||
| std::unique_ptr<std::thread> http_client_thread(nullptr); | |||
| http_client_thread = std::make_unique<std::thread>([&]() { | |||
| client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { | |||
| KVMessage kv_message; | |||
| kv_message.ParseFromString(message.data()); | |||
| EXPECT_EQ(2, kv_message.keys_size()); | |||
| client_->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) { | |||
| KVMessage message; | |||
| message.ParseFromArray(data, size); | |||
| EXPECT_EQ(2, message.keys_size()); | |||
| }); | |||
| client_->Init(); | |||
| CommMessage comm_message; | |||
| KVMessage kv_message; | |||
| std::vector<int> keys{1, 2}; | |||
| std::vector<int> values{3, 4}; | |||
| *kv_message.mutable_keys() = {keys.begin(), keys.end()}; | |||
| *kv_message.mutable_values() = {values.begin(), values.end()}; | |||
| comm_message.set_data(kv_message.SerializeAsString()); | |||
| client_->SendMessage(comm_message); | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| message_meta->set_cmd(NodeCommand::SEND_DATA); | |||
| client_->SendMessage(message_meta, Protos::RAW, kv_message.SerializeAsString().data(), kv_message.ByteSizeLong()); | |||
| client_->Start(); | |||
| }); | |||