| @@ -75,6 +75,8 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string & | |||
| auto client = GetOrCreateTcpClient((*it).first.second); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return Wait(request_id, timeout); | |||
| } | |||
| @@ -126,11 +128,13 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| CommMessage *output, const uint32_t &timeout) { | |||
| std::string *output, const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| @@ -141,7 +145,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, | |||
| set_message_callback(request_id, [&]() { | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| *output = res[rank_id]; | |||
| *output = res[rank_id].data(); | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| }); | |||
| @@ -157,11 +161,13 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, | |||
| comm_message.set_data(message); | |||
| auto client = GetOrCreateTcpClient(rank_id); | |||
| client->SendMessage(comm_message); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return Wait(request_id, timeout); | |||
| } | |||
| bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||
| const std::vector<std::string> &data, std::vector<CommMessage> *output, | |||
| const std::vector<std::string> &data, std::vector<std::string> *output, | |||
| const uint32_t &timeout) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| uint64_t request_id = ++next_request_id_; | |||
| @@ -177,7 +183,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||
| receive_messages_mutex_.lock(); | |||
| auto res = receive_messages_[request_id]; | |||
| for (size_t it = 0; it < len; ++it) { | |||
| (*output).push_back(res[rank_ids.at(it)]); | |||
| (*output).push_back(res[rank_ids.at(it)].data()); | |||
| } | |||
| receive_messages_.erase(request_id); | |||
| receive_messages_mutex_.unlock(); | |||
| @@ -201,6 +207,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | |||
| client->SendMessage(comm_message); | |||
| } | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return Wait(request_id, timeout); | |||
| } | |||
| @@ -215,7 +223,7 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { | |||
| } | |||
| uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | |||
| const std::string &message, const uint32_t &timeout) { | |||
| const std::string &message) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| @@ -233,19 +241,19 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const | |||
| } | |||
| std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, | |||
| const uint32_t &rank_id, CommMessage *output) { | |||
| const uint32_t &rank_id, std::string *output) { | |||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||
| } | |||
| uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); | |||
| if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { | |||
| *output = received_data_[std::make_pair(rank_id, rank_request_id)]; | |||
| *output = received_data_[std::make_pair(rank_id, rank_request_id)].data(); | |||
| received_data_.erase(std::make_pair(rank_id, rank_request_id)); | |||
| } else { | |||
| set_receive_callback(rank_id, rank_request_id, [=]() { | |||
| receive_callbacks_mutex_.lock(); | |||
| *output = received_data_[std::make_pair(rank_id, 1)]; | |||
| *output = received_data_[std::make_pair(rank_id, 1)].data(); | |||
| received_data_.erase(std::make_pair(rank_id, rank_request_id)); | |||
| receive_callbacks_mutex_.unlock(); | |||
| }); | |||
| @@ -272,13 +280,25 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) | |||
| << " begin send heartbeat to the scheduler!"; | |||
| heart_beat_thread_ = std::make_unique<std::thread>([&]() { | |||
| while (!is_finish_.load()) { | |||
| Heartbeat(client); | |||
| if (!Heartbeat(client)) { | |||
| MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; | |||
| if (!CheckSchedulerTimeout() && on_node_event_message_) { | |||
| MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!"; | |||
| is_finish_ = true; | |||
| wait_finish_cond_.notify_all(); | |||
| on_node_event_message_(NodeEvent::SCHEDULER_TIMEOUT); | |||
| } | |||
| } else { | |||
| UpdateSchedulerTime(); | |||
| } | |||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | |||
| } | |||
| }); | |||
| } | |||
| void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | |||
| bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | |||
| MessageMeta meta; | |||
| meta.set_cmd(NodeCommand::HEARTBEAT); | |||
| @@ -292,11 +312,31 @@ void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_n | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | |||
| } | |||
| return true; | |||
| } | |||
| void AbstractNode::UpdateSchedulerTime() { | |||
| struct timeval current_time {}; | |||
| (void)gettimeofday(¤t_time, nullptr); | |||
| scheduler_time_ = current_time; | |||
| MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | |||
| << " update scheduler time, the current time is: " << current_time.tv_sec; | |||
| } | |||
| bool AbstractNode::CheckSchedulerTimeout() const { | |||
| struct timeval current_time {}; | |||
| (void)gettimeofday(¤t_time, nullptr); | |||
| if (scheduler_time_.tv_sec + ClusterConfig::scheduler_timeout() < current_time.tv_sec) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { | |||
| HeartbeatRespMessage heartbeat_resp_message; | |||
| heartbeat_resp_message.ParseFromString(message.data()); | |||
| is_ready_ = heartbeat_resp_message.is_cluster_ready(); | |||
| if (is_ready_.load()) { | |||
| wait_start_cond_.notify_all(); | |||
| @@ -353,9 +393,9 @@ bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const ui | |||
| *message.mutable_pb_meta() = {meta}; | |||
| message.set_data(finish_message.SerializeAsString()); | |||
| if (!SendMessageSync(client, message)) { | |||
| MS_LOG(EXCEPTION) << "Disconnect timeout!"; | |||
| MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!"; | |||
| return WaitForDisconnect(timeout); | |||
| } | |||
| @@ -444,6 +484,8 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return Wait(request_id, timeout); | |||
| } | |||
| @@ -452,6 +494,8 @@ uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client | |||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | |||
| client->SendMessage(message); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return request_id; | |||
| } | |||
| @@ -460,6 +504,8 @@ void AbstractNode::ProcessSendDataResp(const CommMessage &message) { | |||
| const MessageMeta &message_meta = message.pb_meta(); | |||
| const uint32_t &rank_id = message_meta.rank_id(); | |||
| const uint64_t request_id = message_meta.request_id(); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| auto it = receive_messages_.find(request_id); | |||
| if (it != receive_messages_.end()) { | |||
| it->second[rank_id] = message; | |||
| @@ -42,23 +42,24 @@ class AbstractNode : public Node { | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output, | |||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||
| std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| std::vector<std::string> *output, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message); | |||
| std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | |||
| CommMessage *output); | |||
| std::string *output); | |||
| bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| protected: | |||
| void Register(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessRegisterResp(const CommMessage &message); | |||
| void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client); | |||
| void Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false); | |||
| bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false); | |||
| void UpdateSchedulerTime(); | |||
| bool CheckSchedulerTimeout() const; | |||
| void ProcessHeartbeatResp(const CommMessage &message); | |||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | |||
| void ProcessFetchServersResp(const CommMessage &message); | |||
| @@ -113,6 +114,7 @@ class AbstractNode : public Node { | |||
| // the key is rank_id, the value is rank_id's actual request_id | |||
| std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_; | |||
| std::mutex rank_request_ids_mutex; | |||
| timeval scheduler_time_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -33,15 +33,17 @@ uint32_t ClusterConfig::heartbeat_timeout_ = 30; | |||
| uint32_t ClusterConfig::cluster_available_timeout_ = 300; | |||
| // The timeout period for the client to connect to the server is 100ms. | |||
| uint32_t ClusterConfig::connect_interval_ = 100; | |||
| // When the scheduler exits, the worker and server can continue to work for 5 hours | |||
| uint32_t ClusterConfig::scheduler_timeout_ = 3600 * 5; | |||
| void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, | |||
| std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) { | |||
| void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, | |||
| const uint16_t &scheduler_port) { | |||
| worker_num_ = worker_num; | |||
| server_num_ = server_num; | |||
| if (!CommUtil::CheckIp(*scheduler_host.get())) { | |||
| MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!"; | |||
| if (!CommUtil::CheckIp(scheduler_host)) { | |||
| MS_LOG(EXCEPTION) << "The scheduler_host:" << scheduler_host << " is illegal!"; | |||
| } | |||
| scheduler_host_ = std::move(scheduler_host); | |||
| scheduler_host_ = std::make_unique<std::string>(scheduler_host); | |||
| scheduler_port_ = scheduler_port; | |||
| } | |||
| @@ -55,7 +57,7 @@ void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) { | |||
| heartbeat_interval_ = heartbeat_interval; | |||
| } | |||
| std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); } | |||
| std::string ClusterConfig::scheduler_host() { return *scheduler_host_; } | |||
| uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } | |||
| @@ -74,6 +76,10 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa | |||
| uint32_t ClusterConfig::connect_interval() { return connect_interval_; } | |||
| void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; } | |||
| uint32_t ClusterConfig::scheduler_timeout() { return scheduler_timeout_; } | |||
| void ClusterConfig::set_scheduler_timeout(const uint32_t &scheduler_timeout) { scheduler_timeout_ = scheduler_timeout; } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -30,7 +30,7 @@ namespace ps { | |||
| namespace core { | |||
| class ClusterConfig { | |||
| public: | |||
| static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host, | |||
| static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, | |||
| const uint16_t &scheduler_port); | |||
| static uint32_t worker_num(); | |||
| static uint32_t server_num(); | |||
| @@ -44,6 +44,8 @@ class ClusterConfig { | |||
| static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); | |||
| static uint32_t connect_interval(); | |||
| static void set_connect_interval(const uint32_t &connect_interval); | |||
| static uint32_t scheduler_timeout(); | |||
| static void set_scheduler_timeout(const uint32_t &scheduler_timeout); | |||
| private: | |||
| static uint32_t worker_num_; | |||
| @@ -54,6 +56,7 @@ class ClusterConfig { | |||
| static uint32_t heartbeat_timeout_; | |||
| static uint32_t cluster_available_timeout_; | |||
| static uint32_t connect_interval_; | |||
| static uint32_t scheduler_timeout_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -21,7 +21,12 @@ namespace ps { | |||
| namespace core { | |||
| std::string Node::node_id() const { return node_info_.node_id_; } | |||
| uint32_t Node::rank_id() const { return node_info_.rank_id_; } | |||
| uint32_t Node::rank_id() const { | |||
| if (!is_ready_.load()) { | |||
| MS_LOG(EXCEPTION) << "The cluster is not ready yet to get rank id!"; | |||
| } | |||
| return node_info_.rank_id_; | |||
| } | |||
| NodeRole Node::role() const { return node_info_.node_role_; } | |||
| @@ -30,8 +30,6 @@ | |||
| #include <utility> | |||
| #include <tuple> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "ps/core/node_info.h" | |||
| #include "ps/core/tcp_client.h" | |||
| @@ -25,7 +25,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 }; | |||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT }; | |||
| struct NodeInfo { | |||
| NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} | |||
| @@ -64,8 +64,8 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) { | |||
| struct timeval current_time {}; | |||
| (void)gettimeofday(¤t_time, nullptr); | |||
| heartbeats_[node_id] = current_time; | |||
| MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id | |||
| << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; | |||
| MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id | |||
| << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; | |||
| } | |||
| void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); } | |||
| @@ -31,8 +31,6 @@ | |||
| #include <condition_variable> | |||
| #include <unordered_set> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/node.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/convert_utils_base.h" | |||
| @@ -20,6 +20,7 @@ option optimize_for = LITE_RUNTIME; | |||
| enum PSCommand { | |||
| PUSH = 0; | |||
| PULL = 1; | |||
| INIT_EMBEDDING_TABLE = 2; | |||
| } | |||
| message KVMessage { | |||
| @@ -37,9 +37,10 @@ bool SchedulerNode::Start(const uint32_t &timeout) { | |||
| return true; | |||
| } | |||
| void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message) { | |||
| HeartbeatMessage heartbeat_message; | |||
| heartbeat_message.ParseFromString(message.data()); | |||
| heartbeat_message.ParseFromString(message->data()); | |||
| node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); | |||
| @@ -59,10 +60,10 @@ void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnectio | |||
| heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); | |||
| heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||
| comm_message.set_data(heartbeat_resp_message.SerializeAsString()); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||
| comm_message->set_data(heartbeat_resp_message.SerializeAsString()); | |||
| server->SendMessage(conn, comm_message); | |||
| } | |||
| void SchedulerNode::Initialize() { | |||
| @@ -79,23 +80,23 @@ void SchedulerNode::CreateTcpServer() { | |||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | |||
| uint32_t scheduler_port = ClusterConfig::scheduler_port(); | |||
| server_ = std::make_unique<TcpServer>(scheduler_host, scheduler_port); | |||
| server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port); | |||
| server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| switch (message->pb_meta().cmd()) { | |||
| case NodeCommand::HEARTBEAT: | |||
| ProcessHeartbeat(server, conn, message); | |||
| ProcessHeartbeat(server_, conn, message); | |||
| break; | |||
| case NodeCommand::REGISTER: | |||
| ProcessRegister(server, conn, message); | |||
| ProcessRegister(server_, conn, message); | |||
| break; | |||
| case NodeCommand::FINISH: | |||
| ProcessFinish(server, conn, message); | |||
| ProcessFinish(server_, conn, message); | |||
| break; | |||
| case NodeCommand::FETCH_SERVER: | |||
| ProcessFetchServers(server, conn, message); | |||
| ProcessFetchServers(server_, conn, message); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| }); | |||
| @@ -107,10 +108,11 @@ void SchedulerNode::CreateTcpServer() { | |||
| }); | |||
| } | |||
| void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message) { | |||
| MS_LOG(INFO) << "The scheduler process a register message!"; | |||
| RegisterMessage register_message; | |||
| register_message.ParseFromString(message.data()); | |||
| register_message.ParseFromString(message->data()); | |||
| // assign worker node and server node rank id | |||
| int rank_id = node_manager_.NextRankId(register_message); | |||
| @@ -124,31 +126,32 @@ void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection | |||
| register_resp_message.set_node_id(node_id); | |||
| register_resp_message.set_rank_id(rank_id); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||
| comm_message.set_data(register_resp_message.SerializeAsString()); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||
| comm_message->set_data(register_resp_message.SerializeAsString()); | |||
| server->SendMessage(conn, comm_message); | |||
| } | |||
| void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message) { | |||
| FinishMessage finish_message; | |||
| finish_message.ParseFromString(message.data()); | |||
| finish_message.ParseFromString(message->data()); | |||
| node_manager_.AddFinishNode(finish_message); | |||
| MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, message); | |||
| server->SendMessage(conn, message); | |||
| } | |||
| void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, | |||
| const CommMessage &message) { | |||
| void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message) { | |||
| FetchServersRespMessage fetch_servers_message; | |||
| std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta(); | |||
| *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||
| comm_message.set_data(fetch_servers_message.SerializeAsString()); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||
| comm_message->set_data(fetch_servers_message.SerializeAsString()); | |||
| server->SendMessage(conn, comm_message); | |||
| } | |||
| void SchedulerNode::StartUpdateClusterStateTimer() { | |||
| @@ -26,8 +26,6 @@ | |||
| #include <thread> | |||
| #include <mutex> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "ps/core/tcp_client.h" | |||
| #include "ps/core/tcp_server.h" | |||
| @@ -51,13 +49,17 @@ class SchedulerNode : public Node { | |||
| private: | |||
| void Initialize(); | |||
| void CreateTcpServer(); | |||
| void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message); | |||
| void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message); | |||
| void StartUpdateClusterStateTimer(); | |||
| void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message); | |||
| void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||
| std::shared_ptr<CommMessage> message); | |||
| std::unique_ptr<TcpServer> server_; | |||
| std::shared_ptr<TcpServer> server_; | |||
| std::unique_ptr<std::thread> scheduler_thread_; | |||
| std::unique_ptr<std::thread> update_state_thread_; | |||
| @@ -30,7 +30,8 @@ bool ServerNode::Start(const uint32_t &timeout) { | |||
| StartHeartbeatTimer(client_to_scheduler_); | |||
| if (!WaitForStart(timeout)) { | |||
| MS_LOG(ERROR) << "Start Server node timeout!"; | |||
| MS_LOG(ERROR) << "Start server node timeout!"; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "The cluster is ready to use!"; | |||
| @@ -45,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) { | |||
| void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } | |||
| void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, | |||
| const std::string &message) { | |||
| auto &meta = const_cast<MessageMeta &>(message_meta); | |||
| meta.set_role(node_info_.node_role_); | |||
| meta.set_rank_id(node_info_.rank_id_); | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {meta}; | |||
| comm_message.set_data(message); | |||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||
| void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| message->mutable_pb_meta()->set_role(node_info_.node_role_); | |||
| message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_); | |||
| const MessageMeta &message_meta = message->pb_meta(); | |||
| const uint64_t request_id = message_meta.request_id(); | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| server_->SendMessage(conn, message); | |||
| } | |||
| void ServerNode::CreateTcpServer() { | |||
| @@ -62,17 +63,17 @@ void ServerNode::CreateTcpServer() { | |||
| std::string server_ip; | |||
| CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); | |||
| server_ = std::make_shared<TcpServer>(server_ip, 0); | |||
| server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| switch (message.pb_meta().cmd()) { | |||
| server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| switch (message->pb_meta().cmd()) { | |||
| case NodeCommand::SEND_DATA: | |||
| ProcessSendData(server, conn, message); | |||
| ProcessSendData(conn, message); | |||
| break; | |||
| case NodeCommand::COLLECTIVE_SEND_DATA: | |||
| ProcessCollectiveSendData(server, conn, message); | |||
| RunReceiveCallback(message); | |||
| ProcessCollectiveSendData(conn, message); | |||
| RunReceiveCallback(*message); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||
| MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; | |||
| } | |||
| }); | |||
| server_->Init(); | |||
| @@ -97,15 +98,18 @@ void ServerNode::Initialize() { | |||
| MS_LOG(INFO) << "Server node init client successful!"; | |||
| } | |||
| void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| request_handler_(server, conn, message.pb_meta(), message.data()); | |||
| void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| request_handler_(conn, message); | |||
| } | |||
| void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, | |||
| const CommMessage &message) { | |||
| CommMessage comm_message; | |||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||
| void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||
| server_->SendMessage(conn, comm_message); | |||
| } | |||
| bool ServerNode::Stop() { | |||
| @@ -44,18 +44,16 @@ class ServerNode : public AbstractNode { | |||
| bool Stop() override; | |||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | |||
| using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, const MessageMeta meta, | |||
| const std::string &message)>; | |||
| using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>; | |||
| void set_handler(const RequestHandler &handler); | |||
| void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, | |||
| const std::string &message); | |||
| void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||
| private: | |||
| void CreateTcpServer(); | |||
| void Initialize(); | |||
| void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||
| void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||
| void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||
| std::shared_ptr<TcpServer> server_; | |||
| std::unique_ptr<std::thread> server_thread_; | |||
| @@ -46,9 +46,9 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port) | |||
| server_port_(port), | |||
| is_stop_(true), | |||
| is_connected_(false) { | |||
| message_handler_.SetCallback([this](const CommMessage &message) { | |||
| message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) { | |||
| if (message_callback_) { | |||
| message_callback_(*this, message); | |||
| message_callback_(*this, *message); | |||
| } | |||
| }); | |||
| } | |||
| @@ -105,7 +105,7 @@ void TcpClient::Init() { | |||
| sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); | |||
| sin.sin_port = htons(server_port_); | |||
| buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE); | |||
| buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this); | |||
| @@ -261,17 +261,23 @@ void TcpClient::StartWithNoBlock() { | |||
| void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } | |||
| void TcpClient::SendMessage(const CommMessage &message) const { | |||
| bool TcpClient::SendMessage(const CommMessage &message) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| bufferevent_lock(buffer_event_); | |||
| bool res = true; | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| std::vector<unsigned char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); | |||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | |||
| if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add header failed!"; | |||
| res = false; | |||
| } | |||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; | |||
| if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||
| res = false; | |||
| } | |||
| bufferevent_unlock(buffer_event_); | |||
| return res; | |||
| } | |||
| void TcpClient::StartTimer(const uint32_t &time) { | |||
| @@ -33,8 +33,6 @@ | |||
| #include <condition_variable> | |||
| #include "ps/core/cluster_config.h" | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "utils/convert_utils_base.h" | |||
| namespace mindspore { | |||
| @@ -62,7 +60,7 @@ class TcpClient { | |||
| void Start(); | |||
| void StartWithNoBlock(); | |||
| void SetMessageCallback(const OnMessage &cb); | |||
| void SendMessage(const CommMessage &message) const; | |||
| bool SendMessage(const CommMessage &message) const; | |||
| void StartTimer(const uint32_t &time); | |||
| void set_timer_callback(const OnTimer &timer); | |||
| const event_base &eventbase(); | |||
| @@ -57,8 +57,8 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| } | |||
| if (remaining_length_ == 0) { | |||
| CommMessage pb_message; | |||
| pb_message.ParseFromArray(message_buffer_.get(), message_length_); | |||
| std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>(); | |||
| pb_message->ParseFromArray(message_buffer_.get(), message_length_); | |||
| if (message_callback_) { | |||
| message_callback_(pb_message); | |||
| } | |||
| @@ -30,7 +30,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| using messageReceive = std::function<void(const CommMessage &message)>; | |||
| using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>; | |||
| constexpr int kHeaderLen = 8; | |||
| class TcpMessageHandler { | |||
| @@ -32,14 +32,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void TcpConnection::InitConnection() { | |||
| tcp_message_handler_.SetCallback([&](const CommMessage &message) { | |||
| OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); | |||
| if (on_server_receive) { | |||
| on_server_receive(*server_, *this, message); | |||
| } | |||
| }); | |||
| } | |||
| void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); } | |||
| void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); } | |||
| @@ -49,23 +42,30 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const { | |||
| } | |||
| } | |||
| TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); } | |||
| TcpServer *TcpConnection::GetServer() const { return server_; } | |||
| const evutil_socket_t &TcpConnection::GetFd() const { return fd_; } | |||
| void TcpConnection::SendMessage(const CommMessage &message) const { | |||
| void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; } | |||
| bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| bufferevent_lock(buffer_event_); | |||
| bool res = true; | |||
| size_t buf_size = message->ByteSizeLong(); | |||
| std::vector<unsigned char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); | |||
| if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size, | |||
| sizeof(buf_size)) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | |||
| message->SerializeToArray(serialized.data(), SizeToInt(buf_size)); | |||
| if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add header failed!"; | |||
| res = false; | |||
| } | |||
| if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(), | |||
| buf_size) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; | |||
| if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { | |||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||
| res = false; | |||
| } | |||
| bufferevent_unlock(buffer_event_); | |||
| return res; | |||
| } | |||
| TcpServer::TcpServer(const std::string &address, std::uint16_t port) | |||
| @@ -225,7 +225,7 @@ void TcpServer::SendToAllClients(const char *data, size_t len) { | |||
| } | |||
| } | |||
| void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) { | |||
| void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection) { | |||
| MS_EXCEPTION_IF_NULL(connection); | |||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||
| connections_.insert(std::make_pair(fd, connection)); | |||
| @@ -233,11 +233,11 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co | |||
| void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | |||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||
| TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second); | |||
| delete connection; | |||
| connections_.erase(fd); | |||
| } | |||
| std::shared_ptr<TcpConnection> TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; } | |||
| void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int, | |||
| void *data) { | |||
| auto server = reinterpret_cast<class TcpServer *>(data); | |||
| @@ -246,7 +246,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||
| MS_EXCEPTION_IF_NULL(base); | |||
| MS_EXCEPTION_IF_NULL(sockaddr); | |||
| struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); | |||
| struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); | |||
| if (!bev) { | |||
| MS_LOG(ERROR) << "Error constructing buffer event!"; | |||
| int ret = event_base_loopbreak(base); | |||
| @@ -256,23 +256,29 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||
| return; | |||
| } | |||
| TcpConnection *conn = server->onCreateConnection(bev, fd); | |||
| std::shared_ptr<TcpConnection> conn = server->onCreateConnection(bev, fd); | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| conn->InitConnection(); | |||
| server->AddConnection(fd, conn); | |||
| bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn)); | |||
| conn->InitConnection([=](std::shared_ptr<CommMessage> message) { | |||
| OnServerReceiveMessage on_server_receive = server->GetServerReceive(); | |||
| if (on_server_receive) { | |||
| on_server_receive(conn, message); | |||
| } | |||
| }); | |||
| bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, | |||
| reinterpret_cast<void *>(conn.get())); | |||
| if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { | |||
| MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!"; | |||
| } | |||
| } | |||
| TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { | |||
| TcpConnection *conn = nullptr; | |||
| std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { | |||
| std::shared_ptr<TcpConnection> conn = nullptr; | |||
| if (client_accept_) { | |||
| conn = const_cast<TcpConnection *>(client_accept_(*this)); | |||
| conn = (client_accept_(*this)); | |||
| } else { | |||
| conn = new TcpConnection(bev, fd, this); | |||
| conn = std::make_shared<TcpConnection>(bev, fd, this); | |||
| } | |||
| return conn; | |||
| @@ -312,8 +318,8 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| struct evbuffer *output = bufferevent_get_output(bev); | |||
| size_t remain = evbuffer_get_length(output); | |||
| auto conn = reinterpret_cast<TcpConnection *>(data); | |||
| TcpServer *srv = conn->GetServer(); | |||
| auto conn = static_cast<class TcpConnection *>(data); | |||
| auto srv = conn->GetServer(); | |||
| if (events & BEV_EVENT_EOF) { | |||
| MS_LOG(INFO) << "Event buffer end of file!"; | |||
| @@ -355,13 +361,18 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) { | |||
| } | |||
| } | |||
| void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } | |||
| bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| return conn->SendMessage(message); | |||
| } | |||
| void TcpServer::SendMessage(const CommMessage &message) { | |||
| void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) { | |||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| for (auto it = connections_.begin(); it != connections_.end(); ++it) { | |||
| SendMessage(*it->second, message); | |||
| SendMessage(it->second, message); | |||
| } | |||
| } | |||
| @@ -371,7 +382,7 @@ std::string TcpServer::BoundIp() const { return server_address_; } | |||
| int TcpServer::ConnectionNum() const { return connections_.size(); } | |||
| const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } | |||
| const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; } | |||
| void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | |||
| @@ -34,8 +34,6 @@ | |||
| #include <thread> | |||
| #include <atomic> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/tcp_message_handler.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -47,36 +45,42 @@ namespace core { | |||
| class TcpServer; | |||
| class TcpConnection { | |||
| public: | |||
| explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) | |||
| explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, TcpServer *server) | |||
| : buffer_event_(bev), fd_(fd), server_(server) {} | |||
| TcpConnection(const TcpConnection &); | |||
| virtual ~TcpConnection() = default; | |||
| virtual void InitConnection(); | |||
| using Callback = std::function<void(const std::shared_ptr<CommMessage>)>; | |||
| virtual void InitConnection(const messageReceive &callback); | |||
| virtual void SendMessage(const void *buffer, size_t num) const; | |||
| void SendMessage(const CommMessage &message) const; | |||
| bool SendMessage(std::shared_ptr<CommMessage> message) const; | |||
| virtual void OnReadHandler(const void *buffer, size_t numBytes); | |||
| TcpServer *GetServer() const; | |||
| const evutil_socket_t &GetFd() const; | |||
| void set_callback(const Callback &callback); | |||
| protected: | |||
| struct bufferevent *buffer_event_; | |||
| evutil_socket_t fd_; | |||
| const TcpServer *server_; | |||
| TcpServer *server_; | |||
| TcpMessageHandler tcp_message_handler_; | |||
| Callback callback_; | |||
| }; | |||
| using OnServerReceiveMessage = | |||
| std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const CommMessage &)>; | |||
| std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>; | |||
| class TcpServer { | |||
| public: | |||
| using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>; | |||
| using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>; | |||
| using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>; | |||
| using OnAccepted = std::function<std::shared_ptr<TcpConnection>(const TcpServer &)>; | |||
| using OnTimerOnce = std::function<void(const TcpServer &)>; | |||
| using OnTimer = std::function<void()>; | |||
| explicit TcpServer(const std::string &address, std::uint16_t port); | |||
| TcpServer(const std::string &address, std::uint16_t port); | |||
| TcpServer(const TcpServer &server); | |||
| virtual ~TcpServer(); | |||
| void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | |||
| @@ -90,16 +94,17 @@ class TcpServer { | |||
| void StartTimer(const uint32_t &time); | |||
| void Stop(); | |||
| void SendToAllClients(const char *data, size_t len); | |||
| void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); | |||
| void AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection); | |||
| void RemoveConnection(const evutil_socket_t &fd); | |||
| std::shared_ptr<TcpConnection> GetConnectionByFd(const evutil_socket_t &fd); | |||
| OnServerReceiveMessage GetServerReceive() const; | |||
| void SetMessageCallback(const OnServerReceiveMessage &cb); | |||
| void SendMessage(const TcpConnection &conn, const CommMessage &message); | |||
| void SendMessage(const CommMessage &message); | |||
| bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||
| void SendMessage(std::shared_ptr<CommMessage> message); | |||
| uint16_t BoundPort() const; | |||
| std::string BoundIp() const; | |||
| int ConnectionNum() const; | |||
| const std::map<evutil_socket_t, const TcpConnection *> &Connections() const; | |||
| const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &Connections() const; | |||
| protected: | |||
| static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | |||
| @@ -109,7 +114,7 @@ class TcpServer { | |||
| static void EventCallback(struct bufferevent *, std::int16_t events, void *server); | |||
| static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); | |||
| static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg); | |||
| virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); | |||
| std::shared_ptr<TcpConnection> onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); | |||
| struct event_base *base_; | |||
| struct event *signal_event_; | |||
| @@ -118,7 +123,7 @@ class TcpServer { | |||
| std::uint16_t server_port_; | |||
| std::atomic<bool> is_stop_; | |||
| std::map<evutil_socket_t, const TcpConnection *> connections_; | |||
| std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> connections_; | |||
| OnConnected client_connection_; | |||
| OnDisconnected client_disconnection_; | |||
| OnAccepted client_accept_; | |||
| @@ -24,8 +24,6 @@ | |||
| #include <utility> | |||
| #include <algorithm> | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| #include "ps/core/tcp_client.h" | |||
| #include "ps/core/tcp_server.h" | |||
| @@ -31,7 +31,7 @@ class TestClusterAvailableTimeout : public UT::Common { | |||
| }; | |||
| TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) { | |||
| ClusterConfig::Init(1, 1, std::make_unique<std::string>("127.0.0.1"), 9999); | |||
| ClusterConfig::Init(1, 1, "127.0.0.1", 9999); | |||
| ClusterConfig::set_cluster_available_timeout(3); | |||
| SchedulerNode node; | |||
| node.Start(); | |||
| @@ -33,7 +33,7 @@ class TestClusterConfig : public UT::Common { | |||
| }; | |||
| TEST_F(TestClusterConfig, HeartbeatInterval) { | |||
| ClusterConfig::Init(2, 2, std::make_unique<std::string>("127.0.0.1"), 8080); | |||
| ClusterConfig::Init(2, 2, "127.0.0.1", 8080); | |||
| EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3); | |||
| ClusterConfig::set_heartbeat_interval(100); | |||
| EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100); | |||
| @@ -53,7 +53,7 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { | |||
| } | |||
| TEST_F(TestCommUtil, ValidateRankId) { | |||
| ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999); | |||
| ClusterConfig::Init(3, 2, "127.0.0.1", 9999); | |||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); | |||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); | |||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); | |||
| @@ -35,7 +35,7 @@ class TestTcpMessageHandler : public UT::Common { | |||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); }); | |||
| std::string data(1000, 'a'); | |||
| CommMessage message; | |||
| @@ -55,7 +55,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { | |||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); }); | |||
| std::string data(1000, 'a'); | |||
| CommMessage message; | |||
| @@ -86,7 +86,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { | |||
| TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); }); | |||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); }); | |||
| std::string data(4081, 'a'); | |||
| CommMessage message; | |||
| @@ -126,7 +126,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { | |||
| TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); }); | |||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); }); | |||
| std::string data(4077, 'a'); | |||
| CommMessage message; | |||
| @@ -32,12 +32,12 @@ class TestTcpServer : public UT::Common { | |||
| void SetUp() override { | |||
| server_ = std::make_unique<TcpServer>("127.0.0.1", 0); | |||
| std::unique_ptr<std::thread> http_server_thread_(nullptr); | |||
| http_server_thread_ = std::make_unique<std::thread>([&]() { | |||
| server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| http_server_thread_ = std::make_unique<std::thread>([=]() { | |||
| server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||
| KVMessage kv_message; | |||
| kv_message.ParseFromString(message.data()); | |||
| kv_message.ParseFromString(message->data()); | |||
| EXPECT_EQ(2, kv_message.keys_size()); | |||
| const_cast<TcpServer&>(server).SendMessage(conn, message); | |||
| server_->SendMessage(conn, message); | |||
| }); | |||
| server_->Init(); | |||
| server_->Start(); | |||
| @@ -58,6 +58,7 @@ class TestTcpServer : public UT::Common { | |||
| TEST_F(TestTcpServer, ServerSendMessage) { | |||
| client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort()); | |||
| std::cout << server_->BoundPort() << std::endl; | |||
| std::unique_ptr<std::thread> http_client_thread(nullptr); | |||
| http_client_thread = std::make_unique<std::thread>([&]() { | |||
| client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { | |||