| @@ -75,6 +75,8 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string & | |||||
| auto client = GetOrCreateTcpClient((*it).first.second); | auto client = GetOrCreateTcpClient((*it).first.second); | ||||
| client->SendMessage(comm_message); | client->SendMessage(comm_message); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||||
| return Wait(request_id, timeout); | return Wait(request_id, timeout); | ||||
| } | } | ||||
| @@ -126,11 +128,13 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | auto client = GetOrCreateTcpClient(rank_ids.at(it)); | ||||
| client->SendMessage(comm_message); | client->SendMessage(comm_message); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||||
| return Wait(request_id, timeout); | return Wait(request_id, timeout); | ||||
| } | } | ||||
| bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | ||||
| CommMessage *output, const uint32_t &timeout) { | |||||
| std::string *output, const uint32_t &timeout) { | |||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | if (!CommUtil::ValidateRankId(node_role, rank_id)) { | ||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | ||||
| @@ -141,7 +145,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, | |||||
| set_message_callback(request_id, [&]() { | set_message_callback(request_id, [&]() { | ||||
| receive_messages_mutex_.lock(); | receive_messages_mutex_.lock(); | ||||
| auto res = receive_messages_[request_id]; | auto res = receive_messages_[request_id]; | ||||
| *output = res[rank_id]; | |||||
| *output = res[rank_id].data(); | |||||
| receive_messages_.erase(request_id); | receive_messages_.erase(request_id); | ||||
| receive_messages_mutex_.unlock(); | receive_messages_mutex_.unlock(); | ||||
| }); | }); | ||||
| @@ -157,11 +161,13 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, | |||||
| comm_message.set_data(message); | comm_message.set_data(message); | ||||
| auto client = GetOrCreateTcpClient(rank_id); | auto client = GetOrCreateTcpClient(rank_id); | ||||
| client->SendMessage(comm_message); | client->SendMessage(comm_message); | ||||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||||
| return Wait(request_id, timeout); | return Wait(request_id, timeout); | ||||
| } | } | ||||
| bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | ||||
| const std::vector<std::string> &data, std::vector<CommMessage> *output, | |||||
| const std::vector<std::string> &data, std::vector<std::string> *output, | |||||
| const uint32_t &timeout) { | const uint32_t &timeout) { | ||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| uint64_t request_id = ++next_request_id_; | uint64_t request_id = ++next_request_id_; | ||||
| @@ -177,7 +183,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||||
| receive_messages_mutex_.lock(); | receive_messages_mutex_.lock(); | ||||
| auto res = receive_messages_[request_id]; | auto res = receive_messages_[request_id]; | ||||
| for (size_t it = 0; it < len; ++it) { | for (size_t it = 0; it < len; ++it) { | ||||
| (*output).push_back(res[rank_ids.at(it)]); | |||||
| (*output).push_back(res[rank_ids.at(it)].data()); | |||||
| } | } | ||||
| receive_messages_.erase(request_id); | receive_messages_.erase(request_id); | ||||
| receive_messages_mutex_.unlock(); | receive_messages_mutex_.unlock(); | ||||
| @@ -201,6 +207,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||||
| auto client = GetOrCreateTcpClient(rank_ids.at(it)); | auto client = GetOrCreateTcpClient(rank_ids.at(it)); | ||||
| client->SendMessage(comm_message); | client->SendMessage(comm_message); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||||
| return Wait(request_id, timeout); | return Wait(request_id, timeout); | ||||
| } | } | ||||
| @@ -215,7 +223,7 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { | |||||
| } | } | ||||
| uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | ||||
| const std::string &message, const uint32_t &timeout) { | |||||
| const std::string &message) { | |||||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | if (!CommUtil::ValidateRankId(node_role, rank_id)) { | ||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | ||||
| } | } | ||||
| @@ -233,19 +241,19 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const | |||||
| } | } | ||||
| std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, | std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, | ||||
| const uint32_t &rank_id, CommMessage *output) { | |||||
| const uint32_t &rank_id, std::string *output) { | |||||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | if (!CommUtil::ValidateRankId(node_role, rank_id)) { | ||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | ||||
| } | } | ||||
| uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); | uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); | ||||
| if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { | if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { | ||||
| *output = received_data_[std::make_pair(rank_id, rank_request_id)]; | |||||
| *output = received_data_[std::make_pair(rank_id, rank_request_id)].data(); | |||||
| received_data_.erase(std::make_pair(rank_id, rank_request_id)); | received_data_.erase(std::make_pair(rank_id, rank_request_id)); | ||||
| } else { | } else { | ||||
| set_receive_callback(rank_id, rank_request_id, [=]() { | set_receive_callback(rank_id, rank_request_id, [=]() { | ||||
| receive_callbacks_mutex_.lock(); | receive_callbacks_mutex_.lock(); | ||||
| *output = received_data_[std::make_pair(rank_id, 1)]; | |||||
| *output = received_data_[std::make_pair(rank_id, 1)].data(); | |||||
| received_data_.erase(std::make_pair(rank_id, rank_request_id)); | received_data_.erase(std::make_pair(rank_id, rank_request_id)); | ||||
| receive_callbacks_mutex_.unlock(); | receive_callbacks_mutex_.unlock(); | ||||
| }); | }); | ||||
| @@ -272,13 +280,25 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) | |||||
| << " begin send heartbeat to the scheduler!"; | << " begin send heartbeat to the scheduler!"; | ||||
| heart_beat_thread_ = std::make_unique<std::thread>([&]() { | heart_beat_thread_ = std::make_unique<std::thread>([&]() { | ||||
| while (!is_finish_.load()) { | while (!is_finish_.load()) { | ||||
| Heartbeat(client); | |||||
| if (!Heartbeat(client)) { | |||||
| MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; | |||||
| if (!CheckSchedulerTimeout() && on_node_event_message_) { | |||||
| MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!"; | |||||
| is_finish_ = true; | |||||
| wait_finish_cond_.notify_all(); | |||||
| on_node_event_message_(NodeEvent::SCHEDULER_TIMEOUT); | |||||
| } | |||||
| } else { | |||||
| UpdateSchedulerTime(); | |||||
| } | |||||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | ||||
| } | } | ||||
| }); | }); | ||||
| } | } | ||||
| void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | |||||
| bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | |||||
| MessageMeta meta; | MessageMeta meta; | ||||
| meta.set_cmd(NodeCommand::HEARTBEAT); | meta.set_cmd(NodeCommand::HEARTBEAT); | ||||
| @@ -292,11 +312,31 @@ void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_n | |||||
| if (!SendMessageSync(client, message)) { | if (!SendMessageSync(client, message)) { | ||||
| MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | ||||
| } | } | ||||
| return true; | |||||
| } | |||||
| void AbstractNode::UpdateSchedulerTime() { | |||||
| struct timeval current_time {}; | |||||
| (void)gettimeofday(¤t_time, nullptr); | |||||
| scheduler_time_ = current_time; | |||||
| MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | |||||
| << " update scheduler time, the current time is: " << current_time.tv_sec; | |||||
| } | |||||
| bool AbstractNode::CheckSchedulerTimeout() const { | |||||
| struct timeval current_time {}; | |||||
| (void)gettimeofday(¤t_time, nullptr); | |||||
| if (scheduler_time_.tv_sec + ClusterConfig::scheduler_timeout() < current_time.tv_sec) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | } | ||||
| void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { | void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { | ||||
| HeartbeatRespMessage heartbeat_resp_message; | HeartbeatRespMessage heartbeat_resp_message; | ||||
| heartbeat_resp_message.ParseFromString(message.data()); | heartbeat_resp_message.ParseFromString(message.data()); | ||||
| is_ready_ = heartbeat_resp_message.is_cluster_ready(); | is_ready_ = heartbeat_resp_message.is_cluster_ready(); | ||||
| if (is_ready_.load()) { | if (is_ready_.load()) { | ||||
| wait_start_cond_.notify_all(); | wait_start_cond_.notify_all(); | ||||
| @@ -353,9 +393,9 @@ bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const ui | |||||
| *message.mutable_pb_meta() = {meta}; | *message.mutable_pb_meta() = {meta}; | ||||
| message.set_data(finish_message.SerializeAsString()); | message.set_data(finish_message.SerializeAsString()); | ||||
| if (!SendMessageSync(client, message)) { | if (!SendMessageSync(client, message)) { | ||||
| MS_LOG(EXCEPTION) << "Disconnect timeout!"; | |||||
| MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!"; | |||||
| } | } | ||||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!"; | |||||
| return WaitForDisconnect(timeout); | return WaitForDisconnect(timeout); | ||||
| } | } | ||||
| @@ -444,6 +484,8 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con | |||||
| message_tracker_[request_id] = std::make_pair(1, 0); | message_tracker_[request_id] = std::make_pair(1, 0); | ||||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | ||||
| client->SendMessage(message); | client->SendMessage(message); | ||||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||||
| return Wait(request_id, timeout); | return Wait(request_id, timeout); | ||||
| } | } | ||||
| @@ -452,6 +494,8 @@ uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client | |||||
| message_tracker_[request_id] = std::make_pair(1, 0); | message_tracker_[request_id] = std::make_pair(1, 0); | ||||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | ||||
| client->SendMessage(message); | client->SendMessage(message); | ||||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||||
| return request_id; | return request_id; | ||||
| } | } | ||||
| @@ -460,6 +504,8 @@ void AbstractNode::ProcessSendDataResp(const CommMessage &message) { | |||||
| const MessageMeta &message_meta = message.pb_meta(); | const MessageMeta &message_meta = message.pb_meta(); | ||||
| const uint32_t &rank_id = message_meta.rank_id(); | const uint32_t &rank_id = message_meta.rank_id(); | ||||
| const uint64_t request_id = message_meta.request_id(); | const uint64_t request_id = message_meta.request_id(); | ||||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||||
| auto it = receive_messages_.find(request_id); | auto it = receive_messages_.find(request_id); | ||||
| if (it != receive_messages_.end()) { | if (it != receive_messages_.end()) { | ||||
| it->second[rank_id] = message; | it->second[rank_id] = message; | ||||
| @@ -42,23 +42,24 @@ class AbstractNode : public Node { | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | ||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output, | |||||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | ||||
| std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| std::vector<std::string> *output, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message); | |||||
| std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | ||||
| CommMessage *output); | |||||
| std::string *output); | |||||
| bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| protected: | protected: | ||||
| void Register(const std::shared_ptr<TcpClient> &client); | void Register(const std::shared_ptr<TcpClient> &client); | ||||
| void ProcessRegisterResp(const CommMessage &message); | void ProcessRegisterResp(const CommMessage &message); | ||||
| void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client); | void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client); | ||||
| void Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false); | |||||
| bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false); | |||||
| void UpdateSchedulerTime(); | |||||
| bool CheckSchedulerTimeout() const; | |||||
| void ProcessHeartbeatResp(const CommMessage &message); | void ProcessHeartbeatResp(const CommMessage &message); | ||||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | void FetchServers(const std::shared_ptr<TcpClient> &client); | ||||
| void ProcessFetchServersResp(const CommMessage &message); | void ProcessFetchServersResp(const CommMessage &message); | ||||
| @@ -113,6 +114,7 @@ class AbstractNode : public Node { | |||||
| // the key is rank_id, the value is rank_id's actual request_id | // the key is rank_id, the value is rank_id's actual request_id | ||||
| std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_; | std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_; | ||||
| std::mutex rank_request_ids_mutex; | std::mutex rank_request_ids_mutex; | ||||
| timeval scheduler_time_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -33,15 +33,17 @@ uint32_t ClusterConfig::heartbeat_timeout_ = 30; | |||||
| uint32_t ClusterConfig::cluster_available_timeout_ = 300; | uint32_t ClusterConfig::cluster_available_timeout_ = 300; | ||||
| // The timeout period for the client to connect to the server is 100ms. | // The timeout period for the client to connect to the server is 100ms. | ||||
| uint32_t ClusterConfig::connect_interval_ = 100; | uint32_t ClusterConfig::connect_interval_ = 100; | ||||
| // When the scheduler exits, the worker and server can continue to work for 5 hours | |||||
| uint32_t ClusterConfig::scheduler_timeout_ = 3600 * 5; | |||||
| void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, | |||||
| std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) { | |||||
| void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, | |||||
| const uint16_t &scheduler_port) { | |||||
| worker_num_ = worker_num; | worker_num_ = worker_num; | ||||
| server_num_ = server_num; | server_num_ = server_num; | ||||
| if (!CommUtil::CheckIp(*scheduler_host.get())) { | |||||
| MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!"; | |||||
| if (!CommUtil::CheckIp(scheduler_host)) { | |||||
| MS_LOG(EXCEPTION) << "The scheduler_host:" << scheduler_host << " is illegal!"; | |||||
| } | } | ||||
| scheduler_host_ = std::move(scheduler_host); | |||||
| scheduler_host_ = std::make_unique<std::string>(scheduler_host); | |||||
| scheduler_port_ = scheduler_port; | scheduler_port_ = scheduler_port; | ||||
| } | } | ||||
| @@ -55,7 +57,7 @@ void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) { | |||||
| heartbeat_interval_ = heartbeat_interval; | heartbeat_interval_ = heartbeat_interval; | ||||
| } | } | ||||
| std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); } | |||||
| std::string ClusterConfig::scheduler_host() { return *scheduler_host_; } | |||||
| uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } | uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } | ||||
| @@ -74,6 +76,10 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa | |||||
| uint32_t ClusterConfig::connect_interval() { return connect_interval_; } | uint32_t ClusterConfig::connect_interval() { return connect_interval_; } | ||||
| void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; } | void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; } | ||||
| uint32_t ClusterConfig::scheduler_timeout() { return scheduler_timeout_; } | |||||
| void ClusterConfig::set_scheduler_timeout(const uint32_t &scheduler_timeout) { scheduler_timeout_ = scheduler_timeout; } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,7 +30,7 @@ namespace ps { | |||||
| namespace core { | namespace core { | ||||
| class ClusterConfig { | class ClusterConfig { | ||||
| public: | public: | ||||
| static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host, | |||||
| static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, | |||||
| const uint16_t &scheduler_port); | const uint16_t &scheduler_port); | ||||
| static uint32_t worker_num(); | static uint32_t worker_num(); | ||||
| static uint32_t server_num(); | static uint32_t server_num(); | ||||
| @@ -44,6 +44,8 @@ class ClusterConfig { | |||||
| static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); | static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); | ||||
| static uint32_t connect_interval(); | static uint32_t connect_interval(); | ||||
| static void set_connect_interval(const uint32_t &connect_interval); | static void set_connect_interval(const uint32_t &connect_interval); | ||||
| static uint32_t scheduler_timeout(); | |||||
| static void set_scheduler_timeout(const uint32_t &scheduler_timeout); | |||||
| private: | private: | ||||
| static uint32_t worker_num_; | static uint32_t worker_num_; | ||||
| @@ -54,6 +56,7 @@ class ClusterConfig { | |||||
| static uint32_t heartbeat_timeout_; | static uint32_t heartbeat_timeout_; | ||||
| static uint32_t cluster_available_timeout_; | static uint32_t cluster_available_timeout_; | ||||
| static uint32_t connect_interval_; | static uint32_t connect_interval_; | ||||
| static uint32_t scheduler_timeout_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -21,7 +21,12 @@ namespace ps { | |||||
| namespace core { | namespace core { | ||||
| std::string Node::node_id() const { return node_info_.node_id_; } | std::string Node::node_id() const { return node_info_.node_id_; } | ||||
| uint32_t Node::rank_id() const { return node_info_.rank_id_; } | |||||
| uint32_t Node::rank_id() const { | |||||
| if (!is_ready_.load()) { | |||||
| MS_LOG(EXCEPTION) << "The cluster is not ready yet to get rank id!"; | |||||
| } | |||||
| return node_info_.rank_id_; | |||||
| } | |||||
| NodeRole Node::role() const { return node_info_.node_role_; } | NodeRole Node::role() const { return node_info_.node_role_; } | ||||
| @@ -30,8 +30,6 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <tuple> | #include <tuple> | ||||
| #include "proto/comm.pb.h" | |||||
| #include "proto/ps.pb.h" | |||||
| #include "ps/core/cluster_config.h" | #include "ps/core/cluster_config.h" | ||||
| #include "ps/core/node_info.h" | #include "ps/core/node_info.h" | ||||
| #include "ps/core/tcp_client.h" | #include "ps/core/tcp_client.h" | ||||
| @@ -25,7 +25,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 }; | |||||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT }; | |||||
| struct NodeInfo { | struct NodeInfo { | ||||
| NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} | NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} | ||||
| @@ -64,8 +64,8 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) { | |||||
| struct timeval current_time {}; | struct timeval current_time {}; | ||||
| (void)gettimeofday(¤t_time, nullptr); | (void)gettimeofday(¤t_time, nullptr); | ||||
| heartbeats_[node_id] = current_time; | heartbeats_[node_id] = current_time; | ||||
| MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id | |||||
| << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; | |||||
| MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id | |||||
| << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; | |||||
| } | } | ||||
| void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); } | void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); } | ||||
| @@ -31,8 +31,6 @@ | |||||
| #include <condition_variable> | #include <condition_variable> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include "proto/comm.pb.h" | |||||
| #include "proto/ps.pb.h" | |||||
| #include "ps/core/node.h" | #include "ps/core/node.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/convert_utils_base.h" | #include "utils/convert_utils_base.h" | ||||
| @@ -20,6 +20,7 @@ option optimize_for = LITE_RUNTIME; | |||||
| enum PSCommand { | enum PSCommand { | ||||
| PUSH = 0; | PUSH = 0; | ||||
| PULL = 1; | PULL = 1; | ||||
| INIT_EMBEDDING_TABLE = 2; | |||||
| } | } | ||||
| message KVMessage { | message KVMessage { | ||||
| @@ -37,9 +37,10 @@ bool SchedulerNode::Start(const uint32_t &timeout) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||||
| void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||||
| std::shared_ptr<CommMessage> message) { | |||||
| HeartbeatMessage heartbeat_message; | HeartbeatMessage heartbeat_message; | ||||
| heartbeat_message.ParseFromString(message.data()); | |||||
| heartbeat_message.ParseFromString(message->data()); | |||||
| node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); | node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); | ||||
| @@ -59,10 +60,10 @@ void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnectio | |||||
| heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); | heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); | ||||
| heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); | heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); | ||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||||
| comm_message.set_data(heartbeat_resp_message.SerializeAsString()); | |||||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||||
| comm_message->set_data(heartbeat_resp_message.SerializeAsString()); | |||||
| server->SendMessage(conn, comm_message); | |||||
| } | } | ||||
| void SchedulerNode::Initialize() { | void SchedulerNode::Initialize() { | ||||
| @@ -79,23 +80,23 @@ void SchedulerNode::CreateTcpServer() { | |||||
| std::string scheduler_host = ClusterConfig::scheduler_host(); | std::string scheduler_host = ClusterConfig::scheduler_host(); | ||||
| uint32_t scheduler_port = ClusterConfig::scheduler_port(); | uint32_t scheduler_port = ClusterConfig::scheduler_port(); | ||||
| server_ = std::make_unique<TcpServer>(scheduler_host, scheduler_port); | |||||
| server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||||
| switch (message.pb_meta().cmd()) { | |||||
| server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port); | |||||
| server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||||
| switch (message->pb_meta().cmd()) { | |||||
| case NodeCommand::HEARTBEAT: | case NodeCommand::HEARTBEAT: | ||||
| ProcessHeartbeat(server, conn, message); | |||||
| ProcessHeartbeat(server_, conn, message); | |||||
| break; | break; | ||||
| case NodeCommand::REGISTER: | case NodeCommand::REGISTER: | ||||
| ProcessRegister(server, conn, message); | |||||
| ProcessRegister(server_, conn, message); | |||||
| break; | break; | ||||
| case NodeCommand::FINISH: | case NodeCommand::FINISH: | ||||
| ProcessFinish(server, conn, message); | |||||
| ProcessFinish(server_, conn, message); | |||||
| break; | break; | ||||
| case NodeCommand::FETCH_SERVER: | case NodeCommand::FETCH_SERVER: | ||||
| ProcessFetchServers(server, conn, message); | |||||
| ProcessFetchServers(server_, conn, message); | |||||
| break; | break; | ||||
| default: | default: | ||||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||||
| MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; | |||||
| } | } | ||||
| }); | }); | ||||
| @@ -107,10 +108,11 @@ void SchedulerNode::CreateTcpServer() { | |||||
| }); | }); | ||||
| } | } | ||||
| void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||||
| void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||||
| std::shared_ptr<CommMessage> message) { | |||||
| MS_LOG(INFO) << "The scheduler process a register message!"; | MS_LOG(INFO) << "The scheduler process a register message!"; | ||||
| RegisterMessage register_message; | RegisterMessage register_message; | ||||
| register_message.ParseFromString(message.data()); | |||||
| register_message.ParseFromString(message->data()); | |||||
| // assign worker node and server node rank id | // assign worker node and server node rank id | ||||
| int rank_id = node_manager_.NextRankId(register_message); | int rank_id = node_manager_.NextRankId(register_message); | ||||
| @@ -124,31 +126,32 @@ void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection | |||||
| register_resp_message.set_node_id(node_id); | register_resp_message.set_node_id(node_id); | ||||
| register_resp_message.set_rank_id(rank_id); | register_resp_message.set_rank_id(rank_id); | ||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||||
| comm_message.set_data(register_resp_message.SerializeAsString()); | |||||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||||
| comm_message->set_data(register_resp_message.SerializeAsString()); | |||||
| server->SendMessage(conn, comm_message); | |||||
| } | } | ||||
| void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||||
| void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||||
| std::shared_ptr<CommMessage> message) { | |||||
| FinishMessage finish_message; | FinishMessage finish_message; | ||||
| finish_message.ParseFromString(message.data()); | |||||
| finish_message.ParseFromString(message->data()); | |||||
| node_manager_.AddFinishNode(finish_message); | node_manager_.AddFinishNode(finish_message); | ||||
| MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); | MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); | ||||
| const_cast<TcpServer &>(server).SendMessage(conn, message); | |||||
| server->SendMessage(conn, message); | |||||
| } | } | ||||
| void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, | |||||
| const CommMessage &message) { | |||||
| void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||||
| std::shared_ptr<CommMessage> message) { | |||||
| FetchServersRespMessage fetch_servers_message; | FetchServersRespMessage fetch_servers_message; | ||||
| std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta(); | std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta(); | ||||
| *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; | *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; | ||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||||
| comm_message.set_data(fetch_servers_message.SerializeAsString()); | |||||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||||
| comm_message->set_data(fetch_servers_message.SerializeAsString()); | |||||
| server->SendMessage(conn, comm_message); | |||||
| } | } | ||||
| void SchedulerNode::StartUpdateClusterStateTimer() { | void SchedulerNode::StartUpdateClusterStateTimer() { | ||||
| @@ -26,8 +26,6 @@ | |||||
| #include <thread> | #include <thread> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include "proto/comm.pb.h" | |||||
| #include "proto/ps.pb.h" | |||||
| #include "ps/core/cluster_config.h" | #include "ps/core/cluster_config.h" | ||||
| #include "ps/core/tcp_client.h" | #include "ps/core/tcp_client.h" | ||||
| #include "ps/core/tcp_server.h" | #include "ps/core/tcp_server.h" | ||||
| @@ -51,13 +49,17 @@ class SchedulerNode : public Node { | |||||
| private: | private: | ||||
| void Initialize(); | void Initialize(); | ||||
| void CreateTcpServer(); | void CreateTcpServer(); | ||||
| void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||||
| void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||||
| void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||||
| std::shared_ptr<CommMessage> message); | |||||
| void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||||
| std::shared_ptr<CommMessage> message); | |||||
| void StartUpdateClusterStateTimer(); | void StartUpdateClusterStateTimer(); | ||||
| void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||||
| void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||||
| void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||||
| std::shared_ptr<CommMessage> message); | |||||
| void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, | |||||
| std::shared_ptr<CommMessage> message); | |||||
| std::unique_ptr<TcpServer> server_; | |||||
| std::shared_ptr<TcpServer> server_; | |||||
| std::unique_ptr<std::thread> scheduler_thread_; | std::unique_ptr<std::thread> scheduler_thread_; | ||||
| std::unique_ptr<std::thread> update_state_thread_; | std::unique_ptr<std::thread> update_state_thread_; | ||||
| @@ -30,7 +30,8 @@ bool ServerNode::Start(const uint32_t &timeout) { | |||||
| StartHeartbeatTimer(client_to_scheduler_); | StartHeartbeatTimer(client_to_scheduler_); | ||||
| if (!WaitForStart(timeout)) { | if (!WaitForStart(timeout)) { | ||||
| MS_LOG(ERROR) << "Start Server node timeout!"; | |||||
| MS_LOG(ERROR) << "Start server node timeout!"; | |||||
| return false; | |||||
| } | } | ||||
| MS_LOG(INFO) << "The cluster is ready to use!"; | MS_LOG(INFO) << "The cluster is ready to use!"; | ||||
| @@ -45,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) { | |||||
| void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } | void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } | ||||
| void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, | |||||
| const std::string &message) { | |||||
| auto &meta = const_cast<MessageMeta &>(message_meta); | |||||
| meta.set_role(node_info_.node_role_); | |||||
| meta.set_rank_id(node_info_.rank_id_); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {meta}; | |||||
| comm_message.set_data(message); | |||||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||||
| void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||||
| MS_EXCEPTION_IF_NULL(conn); | |||||
| MS_EXCEPTION_IF_NULL(message); | |||||
| message->mutable_pb_meta()->set_role(node_info_.node_role_); | |||||
| message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_); | |||||
| const MessageMeta &message_meta = message->pb_meta(); | |||||
| const uint64_t request_id = message_meta.request_id(); | |||||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||||
| server_->SendMessage(conn, message); | |||||
| } | } | ||||
| void ServerNode::CreateTcpServer() { | void ServerNode::CreateTcpServer() { | ||||
| @@ -62,17 +63,17 @@ void ServerNode::CreateTcpServer() { | |||||
| std::string server_ip; | std::string server_ip; | ||||
| CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); | CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); | ||||
| server_ = std::make_shared<TcpServer>(server_ip, 0); | server_ = std::make_shared<TcpServer>(server_ip, 0); | ||||
| server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||||
| switch (message.pb_meta().cmd()) { | |||||
| server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||||
| switch (message->pb_meta().cmd()) { | |||||
| case NodeCommand::SEND_DATA: | case NodeCommand::SEND_DATA: | ||||
| ProcessSendData(server, conn, message); | |||||
| ProcessSendData(conn, message); | |||||
| break; | break; | ||||
| case NodeCommand::COLLECTIVE_SEND_DATA: | case NodeCommand::COLLECTIVE_SEND_DATA: | ||||
| ProcessCollectiveSendData(server, conn, message); | |||||
| RunReceiveCallback(message); | |||||
| ProcessCollectiveSendData(conn, message); | |||||
| RunReceiveCallback(*message); | |||||
| break; | break; | ||||
| default: | default: | ||||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | |||||
| MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; | |||||
| } | } | ||||
| }); | }); | ||||
| server_->Init(); | server_->Init(); | ||||
| @@ -97,15 +98,18 @@ void ServerNode::Initialize() { | |||||
| MS_LOG(INFO) << "Server node init client successful!"; | MS_LOG(INFO) << "Server node init client successful!"; | ||||
| } | } | ||||
| void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||||
| request_handler_(server, conn, message.pb_meta(), message.data()); | |||||
| void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||||
| MS_EXCEPTION_IF_NULL(conn); | |||||
| MS_EXCEPTION_IF_NULL(message); | |||||
| request_handler_(conn, message); | |||||
| } | } | ||||
| void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, | |||||
| const CommMessage &message) { | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||||
| void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||||
| MS_EXCEPTION_IF_NULL(conn); | |||||
| MS_EXCEPTION_IF_NULL(message); | |||||
| std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); | |||||
| *comm_message->mutable_pb_meta() = {message->pb_meta()}; | |||||
| server_->SendMessage(conn, comm_message); | |||||
| } | } | ||||
| bool ServerNode::Stop() { | bool ServerNode::Stop() { | ||||
| @@ -44,18 +44,16 @@ class ServerNode : public AbstractNode { | |||||
| bool Stop() override; | bool Stop() override; | ||||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | ||||
| using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, const MessageMeta meta, | |||||
| const std::string &message)>; | |||||
| using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>; | |||||
| void set_handler(const RequestHandler &handler); | void set_handler(const RequestHandler &handler); | ||||
| void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, | |||||
| const std::string &message); | |||||
| void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||||
| private: | private: | ||||
| void CreateTcpServer(); | void CreateTcpServer(); | ||||
| void Initialize(); | void Initialize(); | ||||
| void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||||
| void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||||
| void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||||
| void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||||
| std::shared_ptr<TcpServer> server_; | std::shared_ptr<TcpServer> server_; | ||||
| std::unique_ptr<std::thread> server_thread_; | std::unique_ptr<std::thread> server_thread_; | ||||
| @@ -46,9 +46,9 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port) | |||||
| server_port_(port), | server_port_(port), | ||||
| is_stop_(true), | is_stop_(true), | ||||
| is_connected_(false) { | is_connected_(false) { | ||||
| message_handler_.SetCallback([this](const CommMessage &message) { | |||||
| message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) { | |||||
| if (message_callback_) { | if (message_callback_) { | ||||
| message_callback_(*this, message); | |||||
| message_callback_(*this, *message); | |||||
| } | } | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -105,7 +105,7 @@ void TcpClient::Init() { | |||||
| sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); | sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); | ||||
| sin.sin_port = htons(server_port_); | sin.sin_port = htons(server_port_); | ||||
| buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE); | |||||
| buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); | |||||
| MS_EXCEPTION_IF_NULL(buffer_event_); | MS_EXCEPTION_IF_NULL(buffer_event_); | ||||
| bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this); | bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this); | ||||
| @@ -261,17 +261,23 @@ void TcpClient::StartWithNoBlock() { | |||||
| void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } | void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } | ||||
| void TcpClient::SendMessage(const CommMessage &message) const { | |||||
| bool TcpClient::SendMessage(const CommMessage &message) const { | |||||
| MS_EXCEPTION_IF_NULL(buffer_event_); | MS_EXCEPTION_IF_NULL(buffer_event_); | ||||
| bufferevent_lock(buffer_event_); | |||||
| bool res = true; | |||||
| size_t buf_size = message.ByteSizeLong(); | size_t buf_size = message.ByteSizeLong(); | ||||
| std::vector<unsigned char> serialized(buf_size); | std::vector<unsigned char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); | message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); | ||||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { | |||||
| MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | |||||
| if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { | |||||
| MS_LOG(ERROR) << "Event buffer add header failed!"; | |||||
| res = false; | |||||
| } | } | ||||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) { | |||||
| MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; | |||||
| if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { | |||||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||||
| res = false; | |||||
| } | } | ||||
| bufferevent_unlock(buffer_event_); | |||||
| return res; | |||||
| } | } | ||||
| void TcpClient::StartTimer(const uint32_t &time) { | void TcpClient::StartTimer(const uint32_t &time) { | ||||
| @@ -33,8 +33,6 @@ | |||||
| #include <condition_variable> | #include <condition_variable> | ||||
| #include "ps/core/cluster_config.h" | #include "ps/core/cluster_config.h" | ||||
| #include "proto/comm.pb.h" | |||||
| #include "proto/ps.pb.h" | |||||
| #include "utils/convert_utils_base.h" | #include "utils/convert_utils_base.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -62,7 +60,7 @@ class TcpClient { | |||||
| void Start(); | void Start(); | ||||
| void StartWithNoBlock(); | void StartWithNoBlock(); | ||||
| void SetMessageCallback(const OnMessage &cb); | void SetMessageCallback(const OnMessage &cb); | ||||
| void SendMessage(const CommMessage &message) const; | |||||
| bool SendMessage(const CommMessage &message) const; | |||||
| void StartTimer(const uint32_t &time); | void StartTimer(const uint32_t &time); | ||||
| void set_timer_callback(const OnTimer &timer); | void set_timer_callback(const OnTimer &timer); | ||||
| const event_base &eventbase(); | const event_base &eventbase(); | ||||
| @@ -57,8 +57,8 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||||
| } | } | ||||
| if (remaining_length_ == 0) { | if (remaining_length_ == 0) { | ||||
| CommMessage pb_message; | |||||
| pb_message.ParseFromArray(message_buffer_.get(), message_length_); | |||||
| std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>(); | |||||
| pb_message->ParseFromArray(message_buffer_.get(), message_length_); | |||||
| if (message_callback_) { | if (message_callback_) { | ||||
| message_callback_(pb_message); | message_callback_(pb_message); | ||||
| } | } | ||||
| @@ -30,7 +30,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| using messageReceive = std::function<void(const CommMessage &message)>; | |||||
| using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>; | |||||
| constexpr int kHeaderLen = 8; | constexpr int kHeaderLen = 8; | ||||
| class TcpMessageHandler { | class TcpMessageHandler { | ||||
| @@ -32,14 +32,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| void TcpConnection::InitConnection() { | |||||
| tcp_message_handler_.SetCallback([&](const CommMessage &message) { | |||||
| OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); | |||||
| if (on_server_receive) { | |||||
| on_server_receive(*server_, *this, message); | |||||
| } | |||||
| }); | |||||
| } | |||||
| void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); } | |||||
| void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); } | void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); } | ||||
| @@ -49,23 +42,30 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const { | |||||
| } | } | ||||
| } | } | ||||
| TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); } | |||||
| TcpServer *TcpConnection::GetServer() const { return server_; } | |||||
| const evutil_socket_t &TcpConnection::GetFd() const { return fd_; } | const evutil_socket_t &TcpConnection::GetFd() const { return fd_; } | ||||
| void TcpConnection::SendMessage(const CommMessage &message) const { | |||||
| void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; } | |||||
| bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const { | |||||
| MS_EXCEPTION_IF_NULL(buffer_event_); | MS_EXCEPTION_IF_NULL(buffer_event_); | ||||
| size_t buf_size = message.ByteSizeLong(); | |||||
| MS_EXCEPTION_IF_NULL(message); | |||||
| bufferevent_lock(buffer_event_); | |||||
| bool res = true; | |||||
| size_t buf_size = message->ByteSizeLong(); | |||||
| std::vector<unsigned char> serialized(buf_size); | std::vector<unsigned char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); | |||||
| if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size, | |||||
| sizeof(buf_size)) == -1) { | |||||
| MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | |||||
| message->SerializeToArray(serialized.data(), SizeToInt(buf_size)); | |||||
| if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { | |||||
| MS_LOG(ERROR) << "Event buffer add header failed!"; | |||||
| res = false; | |||||
| } | } | ||||
| if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(), | |||||
| buf_size) == -1) { | |||||
| MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; | |||||
| if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { | |||||
| MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; | |||||
| res = false; | |||||
| } | } | ||||
| bufferevent_unlock(buffer_event_); | |||||
| return res; | |||||
| } | } | ||||
| TcpServer::TcpServer(const std::string &address, std::uint16_t port) | TcpServer::TcpServer(const std::string &address, std::uint16_t port) | ||||
| @@ -225,7 +225,7 @@ void TcpServer::SendToAllClients(const char *data, size_t len) { | |||||
| } | } | ||||
| } | } | ||||
| void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) { | |||||
| void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection) { | |||||
| MS_EXCEPTION_IF_NULL(connection); | MS_EXCEPTION_IF_NULL(connection); | ||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | std::lock_guard<std::mutex> lock(connection_mutex_); | ||||
| connections_.insert(std::make_pair(fd, connection)); | connections_.insert(std::make_pair(fd, connection)); | ||||
| @@ -233,11 +233,11 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co | |||||
| void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | ||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | std::lock_guard<std::mutex> lock(connection_mutex_); | ||||
| TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second); | |||||
| delete connection; | |||||
| connections_.erase(fd); | connections_.erase(fd); | ||||
| } | } | ||||
| std::shared_ptr<TcpConnection> TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; } | |||||
| void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int, | void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int, | ||||
| void *data) { | void *data) { | ||||
| auto server = reinterpret_cast<class TcpServer *>(data); | auto server = reinterpret_cast<class TcpServer *>(data); | ||||
| @@ -246,7 +246,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||||
| MS_EXCEPTION_IF_NULL(base); | MS_EXCEPTION_IF_NULL(base); | ||||
| MS_EXCEPTION_IF_NULL(sockaddr); | MS_EXCEPTION_IF_NULL(sockaddr); | ||||
| struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); | |||||
| struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); | |||||
| if (!bev) { | if (!bev) { | ||||
| MS_LOG(ERROR) << "Error constructing buffer event!"; | MS_LOG(ERROR) << "Error constructing buffer event!"; | ||||
| int ret = event_base_loopbreak(base); | int ret = event_base_loopbreak(base); | ||||
| @@ -256,23 +256,29 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||||
| return; | return; | ||||
| } | } | ||||
| TcpConnection *conn = server->onCreateConnection(bev, fd); | |||||
| std::shared_ptr<TcpConnection> conn = server->onCreateConnection(bev, fd); | |||||
| MS_EXCEPTION_IF_NULL(conn); | MS_EXCEPTION_IF_NULL(conn); | ||||
| conn->InitConnection(); | |||||
| server->AddConnection(fd, conn); | server->AddConnection(fd, conn); | ||||
| bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn)); | |||||
| conn->InitConnection([=](std::shared_ptr<CommMessage> message) { | |||||
| OnServerReceiveMessage on_server_receive = server->GetServerReceive(); | |||||
| if (on_server_receive) { | |||||
| on_server_receive(conn, message); | |||||
| } | |||||
| }); | |||||
| bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, | |||||
| reinterpret_cast<void *>(conn.get())); | |||||
| if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { | if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { | ||||
| MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!"; | MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!"; | ||||
| } | } | ||||
| } | } | ||||
| TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { | |||||
| TcpConnection *conn = nullptr; | |||||
| std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { | |||||
| std::shared_ptr<TcpConnection> conn = nullptr; | |||||
| if (client_accept_) { | if (client_accept_) { | ||||
| conn = const_cast<TcpConnection *>(client_accept_(*this)); | |||||
| conn = (client_accept_(*this)); | |||||
| } else { | } else { | ||||
| conn = new TcpConnection(bev, fd, this); | |||||
| conn = std::make_shared<TcpConnection>(bev, fd, this); | |||||
| } | } | ||||
| return conn; | return conn; | ||||
| @@ -312,8 +318,8 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||||
| MS_EXCEPTION_IF_NULL(data); | MS_EXCEPTION_IF_NULL(data); | ||||
| struct evbuffer *output = bufferevent_get_output(bev); | struct evbuffer *output = bufferevent_get_output(bev); | ||||
| size_t remain = evbuffer_get_length(output); | size_t remain = evbuffer_get_length(output); | ||||
| auto conn = reinterpret_cast<TcpConnection *>(data); | |||||
| TcpServer *srv = conn->GetServer(); | |||||
| auto conn = static_cast<class TcpConnection *>(data); | |||||
| auto srv = conn->GetServer(); | |||||
| if (events & BEV_EVENT_EOF) { | if (events & BEV_EVENT_EOF) { | ||||
| MS_LOG(INFO) << "Event buffer end of file!"; | MS_LOG(INFO) << "Event buffer end of file!"; | ||||
| @@ -355,13 +361,18 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) { | |||||
| } | } | ||||
| } | } | ||||
| void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } | |||||
| bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||||
| MS_EXCEPTION_IF_NULL(conn); | |||||
| MS_EXCEPTION_IF_NULL(message); | |||||
| return conn->SendMessage(message); | |||||
| } | |||||
| void TcpServer::SendMessage(const CommMessage &message) { | |||||
| void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) { | |||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | std::lock_guard<std::mutex> lock(connection_mutex_); | ||||
| MS_EXCEPTION_IF_NULL(message); | |||||
| for (auto it = connections_.begin(); it != connections_.end(); ++it) { | for (auto it = connections_.begin(); it != connections_.end(); ++it) { | ||||
| SendMessage(*it->second, message); | |||||
| SendMessage(it->second, message); | |||||
| } | } | ||||
| } | } | ||||
| @@ -371,7 +382,7 @@ std::string TcpServer::BoundIp() const { return server_address_; } | |||||
| int TcpServer::ConnectionNum() const { return connections_.size(); } | int TcpServer::ConnectionNum() const { return connections_.size(); } | ||||
| const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } | |||||
| const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; } | |||||
| void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | ||||
| @@ -34,8 +34,6 @@ | |||||
| #include <thread> | #include <thread> | ||||
| #include <atomic> | #include <atomic> | ||||
| #include "proto/comm.pb.h" | |||||
| #include "proto/ps.pb.h" | |||||
| #include "ps/core/tcp_message_handler.h" | #include "ps/core/tcp_message_handler.h" | ||||
| #include "ps/core/cluster_config.h" | #include "ps/core/cluster_config.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -47,36 +45,42 @@ namespace core { | |||||
| class TcpServer; | class TcpServer; | ||||
| class TcpConnection { | class TcpConnection { | ||||
| public: | public: | ||||
| explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) | |||||
| explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, TcpServer *server) | |||||
| : buffer_event_(bev), fd_(fd), server_(server) {} | : buffer_event_(bev), fd_(fd), server_(server) {} | ||||
| TcpConnection(const TcpConnection &); | |||||
| virtual ~TcpConnection() = default; | virtual ~TcpConnection() = default; | ||||
| virtual void InitConnection(); | |||||
| using Callback = std::function<void(const std::shared_ptr<CommMessage>)>; | |||||
| virtual void InitConnection(const messageReceive &callback); | |||||
| virtual void SendMessage(const void *buffer, size_t num) const; | virtual void SendMessage(const void *buffer, size_t num) const; | ||||
| void SendMessage(const CommMessage &message) const; | |||||
| bool SendMessage(std::shared_ptr<CommMessage> message) const; | |||||
| virtual void OnReadHandler(const void *buffer, size_t numBytes); | virtual void OnReadHandler(const void *buffer, size_t numBytes); | ||||
| TcpServer *GetServer() const; | TcpServer *GetServer() const; | ||||
| const evutil_socket_t &GetFd() const; | const evutil_socket_t &GetFd() const; | ||||
| void set_callback(const Callback &callback); | |||||
| protected: | protected: | ||||
| struct bufferevent *buffer_event_; | struct bufferevent *buffer_event_; | ||||
| evutil_socket_t fd_; | evutil_socket_t fd_; | ||||
| const TcpServer *server_; | |||||
| TcpServer *server_; | |||||
| TcpMessageHandler tcp_message_handler_; | TcpMessageHandler tcp_message_handler_; | ||||
| Callback callback_; | |||||
| }; | }; | ||||
| using OnServerReceiveMessage = | using OnServerReceiveMessage = | ||||
| std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const CommMessage &)>; | |||||
| std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>; | |||||
| class TcpServer { | class TcpServer { | ||||
| public: | public: | ||||
| using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>; | using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>; | ||||
| using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>; | using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>; | ||||
| using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>; | |||||
| using OnAccepted = std::function<std::shared_ptr<TcpConnection>(const TcpServer &)>; | |||||
| using OnTimerOnce = std::function<void(const TcpServer &)>; | using OnTimerOnce = std::function<void(const TcpServer &)>; | ||||
| using OnTimer = std::function<void()>; | using OnTimer = std::function<void()>; | ||||
| explicit TcpServer(const std::string &address, std::uint16_t port); | |||||
| TcpServer(const std::string &address, std::uint16_t port); | |||||
| TcpServer(const TcpServer &server); | |||||
| virtual ~TcpServer(); | virtual ~TcpServer(); | ||||
| void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | ||||
| @@ -90,16 +94,17 @@ class TcpServer { | |||||
| void StartTimer(const uint32_t &time); | void StartTimer(const uint32_t &time); | ||||
| void Stop(); | void Stop(); | ||||
| void SendToAllClients(const char *data, size_t len); | void SendToAllClients(const char *data, size_t len); | ||||
| void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); | |||||
| void AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection); | |||||
| void RemoveConnection(const evutil_socket_t &fd); | void RemoveConnection(const evutil_socket_t &fd); | ||||
| std::shared_ptr<TcpConnection> GetConnectionByFd(const evutil_socket_t &fd); | |||||
| OnServerReceiveMessage GetServerReceive() const; | OnServerReceiveMessage GetServerReceive() const; | ||||
| void SetMessageCallback(const OnServerReceiveMessage &cb); | void SetMessageCallback(const OnServerReceiveMessage &cb); | ||||
| void SendMessage(const TcpConnection &conn, const CommMessage &message); | |||||
| void SendMessage(const CommMessage &message); | |||||
| bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); | |||||
| void SendMessage(std::shared_ptr<CommMessage> message); | |||||
| uint16_t BoundPort() const; | uint16_t BoundPort() const; | ||||
| std::string BoundIp() const; | std::string BoundIp() const; | ||||
| int ConnectionNum() const; | int ConnectionNum() const; | ||||
| const std::map<evutil_socket_t, const TcpConnection *> &Connections() const; | |||||
| const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &Connections() const; | |||||
| protected: | protected: | ||||
| static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | ||||
| @@ -109,7 +114,7 @@ class TcpServer { | |||||
| static void EventCallback(struct bufferevent *, std::int16_t events, void *server); | static void EventCallback(struct bufferevent *, std::int16_t events, void *server); | ||||
| static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); | static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); | ||||
| static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg); | static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg); | ||||
| virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); | |||||
| std::shared_ptr<TcpConnection> onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); | |||||
| struct event_base *base_; | struct event_base *base_; | ||||
| struct event *signal_event_; | struct event *signal_event_; | ||||
| @@ -118,7 +123,7 @@ class TcpServer { | |||||
| std::uint16_t server_port_; | std::uint16_t server_port_; | ||||
| std::atomic<bool> is_stop_; | std::atomic<bool> is_stop_; | ||||
| std::map<evutil_socket_t, const TcpConnection *> connections_; | |||||
| std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> connections_; | |||||
| OnConnected client_connection_; | OnConnected client_connection_; | ||||
| OnDisconnected client_disconnection_; | OnDisconnected client_disconnection_; | ||||
| OnAccepted client_accept_; | OnAccepted client_accept_; | ||||
| @@ -24,8 +24,6 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "proto/comm.pb.h" | |||||
| #include "proto/ps.pb.h" | |||||
| #include "ps/core/cluster_config.h" | #include "ps/core/cluster_config.h" | ||||
| #include "ps/core/tcp_client.h" | #include "ps/core/tcp_client.h" | ||||
| #include "ps/core/tcp_server.h" | #include "ps/core/tcp_server.h" | ||||
| @@ -31,7 +31,7 @@ class TestClusterAvailableTimeout : public UT::Common { | |||||
| }; | }; | ||||
| TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) { | TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) { | ||||
| ClusterConfig::Init(1, 1, std::make_unique<std::string>("127.0.0.1"), 9999); | |||||
| ClusterConfig::Init(1, 1, "127.0.0.1", 9999); | |||||
| ClusterConfig::set_cluster_available_timeout(3); | ClusterConfig::set_cluster_available_timeout(3); | ||||
| SchedulerNode node; | SchedulerNode node; | ||||
| node.Start(); | node.Start(); | ||||
| @@ -33,7 +33,7 @@ class TestClusterConfig : public UT::Common { | |||||
| }; | }; | ||||
| TEST_F(TestClusterConfig, HeartbeatInterval) { | TEST_F(TestClusterConfig, HeartbeatInterval) { | ||||
| ClusterConfig::Init(2, 2, std::make_unique<std::string>("127.0.0.1"), 8080); | |||||
| ClusterConfig::Init(2, 2, "127.0.0.1", 8080); | |||||
| EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3); | EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3); | ||||
| ClusterConfig::set_heartbeat_interval(100); | ClusterConfig::set_heartbeat_interval(100); | ||||
| EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100); | EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100); | ||||
| @@ -53,7 +53,7 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { | |||||
| } | } | ||||
| TEST_F(TestCommUtil, ValidateRankId) { | TEST_F(TestCommUtil, ValidateRankId) { | ||||
| ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999); | |||||
| ClusterConfig::Init(3, 2, "127.0.0.1", 9999); | |||||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); | EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); | ||||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); | EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); | ||||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); | EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); | ||||
| @@ -35,7 +35,7 @@ class TestTcpMessageHandler : public UT::Common { | |||||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { | TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { | ||||
| TcpMessageHandler handler; | TcpMessageHandler handler; | ||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); }); | |||||
| std::string data(1000, 'a'); | std::string data(1000, 'a'); | ||||
| CommMessage message; | CommMessage message; | ||||
| @@ -55,7 +55,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { | |||||
| TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { | TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { | ||||
| TcpMessageHandler handler; | TcpMessageHandler handler; | ||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); }); | |||||
| std::string data(1000, 'a'); | std::string data(1000, 'a'); | ||||
| CommMessage message; | CommMessage message; | ||||
| @@ -86,7 +86,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { | |||||
| TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { | TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { | ||||
| TcpMessageHandler handler; | TcpMessageHandler handler; | ||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); }); | |||||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); }); | |||||
| std::string data(4081, 'a'); | std::string data(4081, 'a'); | ||||
| CommMessage message; | CommMessage message; | ||||
| @@ -126,7 +126,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { | |||||
| TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { | TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { | ||||
| TcpMessageHandler handler; | TcpMessageHandler handler; | ||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); }); | |||||
| handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); }); | |||||
| std::string data(4077, 'a'); | std::string data(4077, 'a'); | ||||
| CommMessage message; | CommMessage message; | ||||
| @@ -32,12 +32,12 @@ class TestTcpServer : public UT::Common { | |||||
| void SetUp() override { | void SetUp() override { | ||||
| server_ = std::make_unique<TcpServer>("127.0.0.1", 0); | server_ = std::make_unique<TcpServer>("127.0.0.1", 0); | ||||
| std::unique_ptr<std::thread> http_server_thread_(nullptr); | std::unique_ptr<std::thread> http_server_thread_(nullptr); | ||||
| http_server_thread_ = std::make_unique<std::thread>([&]() { | |||||
| server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||||
| http_server_thread_ = std::make_unique<std::thread>([=]() { | |||||
| server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { | |||||
| KVMessage kv_message; | KVMessage kv_message; | ||||
| kv_message.ParseFromString(message.data()); | |||||
| kv_message.ParseFromString(message->data()); | |||||
| EXPECT_EQ(2, kv_message.keys_size()); | EXPECT_EQ(2, kv_message.keys_size()); | ||||
| const_cast<TcpServer&>(server).SendMessage(conn, message); | |||||
| server_->SendMessage(conn, message); | |||||
| }); | }); | ||||
| server_->Init(); | server_->Init(); | ||||
| server_->Start(); | server_->Start(); | ||||
| @@ -58,6 +58,7 @@ class TestTcpServer : public UT::Common { | |||||
| TEST_F(TestTcpServer, ServerSendMessage) { | TEST_F(TestTcpServer, ServerSendMessage) { | ||||
| client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort()); | client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort()); | ||||
| std::cout << server_->BoundPort() << std::endl; | |||||
| std::unique_ptr<std::thread> http_client_thread(nullptr); | std::unique_ptr<std::thread> http_client_thread(nullptr); | ||||
| http_client_thread = std::make_unique<std::thread>([&]() { | http_client_thread = std::make_unique<std::thread>([&]() { | ||||
| client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { | client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { | ||||