From: @anancds Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -53,13 +53,20 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) { | |||||
| MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; | MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; | ||||
| } | } | ||||
| bool AbstractNode::BroadcastToServers(const std::string &message, const uint32_t &timeout) { | |||||
| bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) { | |||||
| if (node_role != NodeRole::SERVER) { | |||||
| MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes"; | |||||
| } | |||||
| uint64_t request_id = ++next_request_id_; | uint64_t request_id = ++next_request_id_; | ||||
| message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); | message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); | ||||
| for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { | for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { | ||||
| MessageMeta message_meta; | MessageMeta message_meta; | ||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | message_meta.set_cmd(NodeCommand::SEND_DATA); | ||||
| message_meta.set_request_id(request_id); | message_meta.set_request_id(request_id); | ||||
| message_meta.set_rank_id(node_info_.rank_id_); | |||||
| message_meta.set_role(node_info_.node_role_); | |||||
| CommMessage comm_message; | CommMessage comm_message; | ||||
| *comm_message.mutable_pb_meta() = {message_meta}; | *comm_message.mutable_pb_meta() = {message_meta}; | ||||
| @@ -82,12 +89,14 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, | |||||
| MessageMeta message_meta; | MessageMeta message_meta; | ||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | message_meta.set_cmd(NodeCommand::SEND_DATA); | ||||
| message_meta.set_rank_id(node_info_.rank_id_); | |||||
| message_meta.set_role(node_info_.node_role_); | |||||
| CommMessage comm_message; | CommMessage comm_message; | ||||
| *comm_message.mutable_pb_meta() = {message_meta}; | *comm_message.mutable_pb_meta() = {message_meta}; | ||||
| comm_message.set_data(message); | comm_message.set_data(message); | ||||
| auto client = GetOrCreateTcpClient(rank_id); | auto client = GetOrCreateTcpClient(rank_id); | ||||
| return SendMessageSync(client, comm_message); | |||||
| return SendMessageSync(client, comm_message, timeout); | |||||
| } | } | ||||
| bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | ||||
| @@ -106,6 +115,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||||
| MessageMeta message_meta; | MessageMeta message_meta; | ||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | message_meta.set_cmd(NodeCommand::SEND_DATA); | ||||
| message_meta.set_request_id(request_id); | message_meta.set_request_id(request_id); | ||||
| message_meta.set_rank_id(node_info_.rank_id_); | |||||
| message_meta.set_role(node_info_.node_role_); | |||||
| CommMessage comm_message; | CommMessage comm_message; | ||||
| *comm_message.mutable_pb_meta() = {message_meta}; | *comm_message.mutable_pb_meta() = {message_meta}; | ||||
| @@ -118,8 +129,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||||
| } | } | ||||
| bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | ||||
| CommMessage *comm_message_resp, const uint32_t &timeout) { | |||||
| MS_EXCEPTION_IF_NULL(comm_message_resp); | |||||
| CommMessage *output, const uint32_t &timeout) { | |||||
| MS_EXCEPTION_IF_NULL(output); | |||||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | if (!CommUtil::ValidateRankId(node_role, rank_id)) { | ||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | ||||
| } | } | ||||
| @@ -129,7 +140,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, | |||||
| set_message_callback(request_id, [&]() { | set_message_callback(request_id, [&]() { | ||||
| receive_messages_mutex_.lock(); | receive_messages_mutex_.lock(); | ||||
| auto res = receive_messages_[request_id]; | auto res = receive_messages_[request_id]; | ||||
| *comm_message_resp = res[rank_id]; | |||||
| *output = res[rank_id]; | |||||
| receive_messages_.erase(request_id); | receive_messages_.erase(request_id); | ||||
| receive_messages_mutex_.unlock(); | receive_messages_mutex_.unlock(); | ||||
| }); | }); | ||||
| @@ -149,9 +160,9 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, | |||||
| } | } | ||||
| bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | ||||
| const std::vector<std::string> &data, std::vector<CommMessage> *comm_message_resp, | |||||
| const std::vector<std::string> &data, std::vector<CommMessage> *output, | |||||
| const uint32_t &timeout) { | const uint32_t &timeout) { | ||||
| MS_EXCEPTION_IF_NULL(comm_message_resp); | |||||
| MS_EXCEPTION_IF_NULL(output); | |||||
| uint64_t request_id = ++next_request_id_; | uint64_t request_id = ++next_request_id_; | ||||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | message_tracker_[request_id] = std::make_pair(data.size(), 0); | ||||
| @@ -165,7 +176,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||||
| receive_messages_mutex_.lock(); | receive_messages_mutex_.lock(); | ||||
| auto res = receive_messages_[request_id]; | auto res = receive_messages_[request_id]; | ||||
| for (size_t it = 0; it < len; ++it) { | for (size_t it = 0; it < len; ++it) { | ||||
| (*comm_message_resp).push_back(res[rank_ids.at(it)]); | |||||
| (*output).push_back(res[rank_ids.at(it)]); | |||||
| } | } | ||||
| receive_messages_.erase(request_id); | receive_messages_.erase(request_id); | ||||
| receive_messages_mutex_.unlock(); | receive_messages_mutex_.unlock(); | ||||
| @@ -179,6 +190,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||||
| MessageMeta message_meta; | MessageMeta message_meta; | ||||
| message_meta.set_cmd(NodeCommand::SEND_DATA); | message_meta.set_cmd(NodeCommand::SEND_DATA); | ||||
| message_meta.set_request_id(request_id); | message_meta.set_request_id(request_id); | ||||
| message_meta.set_rank_id(node_info_.rank_id_); | |||||
| message_meta.set_role(node_info_.node_role_); | |||||
| CommMessage comm_message; | CommMessage comm_message; | ||||
| *comm_message.mutable_pb_meta() = {message_meta}; | *comm_message.mutable_pb_meta() = {message_meta}; | ||||
| @@ -200,6 +213,58 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { | |||||
| return res; | return res; | ||||
| } | } | ||||
| uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | |||||
| const std::string &message, const uint32_t &timeout) { | |||||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||||
| } | |||||
| MessageMeta message_meta; | |||||
| message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA); | |||||
| message_meta.set_rank_id(node_info_.rank_id_); | |||||
| message_meta.set_role(node_info_.node_role_); | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | |||||
| comm_message.set_data(message); | |||||
| auto client = GetOrCreateTcpClient(rank_id); | |||||
| return SendMessageAsync(client, comm_message); | |||||
| } | |||||
| std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, | |||||
| const uint32_t &rank_id, CommMessage *output) { | |||||
| if (!CommUtil::ValidateRankId(node_role, rank_id)) { | |||||
| MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; | |||||
| } | |||||
| uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); | |||||
| if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { | |||||
| *output = received_data_[std::make_pair(rank_id, rank_request_id)]; | |||||
| received_data_.erase(std::make_pair(rank_id, rank_request_id)); | |||||
| } else { | |||||
| set_receive_callback(rank_id, rank_request_id, [=]() { | |||||
| receive_callbacks_mutex_.lock(); | |||||
| *output = received_data_[std::make_pair(rank_id, 1)]; | |||||
| received_data_.erase(std::make_pair(rank_id, rank_request_id)); | |||||
| receive_callbacks_mutex_.unlock(); | |||||
| }); | |||||
| } | |||||
| return std::make_pair(rank_id, rank_request_id); | |||||
| } | |||||
| bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) { | |||||
| std::unique_lock<std::mutex> lock(receive_callbacks_mutex_); | |||||
| bool res = receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { | |||||
| if (actual_rank_request_ids_.count(request_id.first) && | |||||
| (actual_rank_request_ids_[request_id.first] >= request_id.second)) { | |||||
| return true; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| }); | |||||
| return res; | |||||
| } | |||||
| void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) { | void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) { | ||||
| MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | ||||
| << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | ||||
| @@ -210,7 +275,6 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) | |||||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); | ||||
| } | } | ||||
| }); | }); | ||||
| heart_beat_thread_->detach(); | |||||
| } | } | ||||
| void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | ||||
| @@ -334,11 +398,9 @@ bool AbstractNode::InitClientToScheduler() { | |||||
| MS_LOG(INFO) << "The worker node start a tcp client!"; | MS_LOG(INFO) << "The worker node start a tcp client!"; | ||||
| client_to_scheduler_->Start(); | client_to_scheduler_->Start(); | ||||
| }); | }); | ||||
| client_to_scheduler_thread_->detach(); | |||||
| client_to_scheduler_->set_disconnected_callback([&]() { | client_to_scheduler_->set_disconnected_callback([&]() { | ||||
| std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval())); | std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval())); | ||||
| client_to_scheduler_->Stop(); | |||||
| client_to_scheduler_->Init(); | client_to_scheduler_->Init(); | ||||
| }); | }); | ||||
| return client_to_scheduler_->WaitConnected(); | return client_to_scheduler_->WaitConnected(); | ||||
| @@ -361,6 +423,9 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int & | |||||
| ProcessSendDataResp(message); | ProcessSendDataResp(message); | ||||
| RunMessageCallback(message.pb_meta().request_id()); | RunMessageCallback(message.pb_meta().request_id()); | ||||
| break; | break; | ||||
| case NodeCommand::COLLECTIVE_SEND_DATA: | |||||
| MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!"; | |||||
| break; | |||||
| default: | default: | ||||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | ||||
| } | } | ||||
| @@ -381,10 +446,12 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con | |||||
| return Wait(request_id, timeout); | return Wait(request_id, timeout); | ||||
| } | } | ||||
| void AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||||
| uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) { | |||||
| uint64_t request_id = ++next_request_id_; | uint64_t request_id = ++next_request_id_; | ||||
| message_tracker_[request_id] = std::make_pair(1, 0); | |||||
| const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); | ||||
| client->SendMessage(message); | client->SendMessage(message); | ||||
| return request_id; | |||||
| } | } | ||||
| void AbstractNode::ProcessSendDataResp(const CommMessage &message) { | void AbstractNode::ProcessSendDataResp(const CommMessage &message) { | ||||
| @@ -422,12 +489,12 @@ void AbstractNode::RunMessageCallback(const uint64_t &request_id) { | |||||
| message_callbacks_mutex_.unlock(); | message_callbacks_mutex_.unlock(); | ||||
| } | } | ||||
| void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) { | |||||
| if (!message_callback) { | |||||
| void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &callback) { | |||||
| if (!callback) { | |||||
| return; | return; | ||||
| } | } | ||||
| std::lock_guard<std::mutex> lock(message_callbacks_mutex_); | std::lock_guard<std::mutex> lock(message_callbacks_mutex_); | ||||
| message_callbacks_[request_id] = message_callback; | |||||
| message_callbacks_[request_id] = callback; | |||||
| } | } | ||||
| void AbstractNode::NotifyMessageArrival(const CommMessage &message) { | void AbstractNode::NotifyMessageArrival(const CommMessage &message) { | ||||
| @@ -438,6 +505,61 @@ void AbstractNode::NotifyMessageArrival(const CommMessage &message) { | |||||
| message_tracker_[request_id].second++; | message_tracker_[request_id].second++; | ||||
| message_tracker_cond_.notify_all(); | message_tracker_cond_.notify_all(); | ||||
| } | } | ||||
| void AbstractNode::set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, | |||||
| const MessageCallback &callback) { | |||||
| if (!callback) { | |||||
| return; | |||||
| } | |||||
| std::lock_guard<std::mutex> lock(receive_callbacks_mutex_); | |||||
| receive_callbacks_[std::make_pair(rank_id, request_id)] = callback; | |||||
| } | |||||
| void AbstractNode::RunReceiveCallback(const CommMessage &message) { | |||||
| receive_callbacks_mutex_.lock(); | |||||
| uint32_t rank_id = message.pb_meta().rank_id(); | |||||
| // When receiving a collective message, Then generate rank request id,compare with the desired rank request id, | |||||
| // If they are equal, then call the callback function | |||||
| uint64_t rank_request_id = NextActualRankRequestId(rank_id); | |||||
| received_data_[std::make_pair(rank_id, rank_request_id)] = message; | |||||
| auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id)); | |||||
| if (it != receive_callbacks_.end()) { | |||||
| receive_callbacks_mutex_.unlock(); | |||||
| if (it->second) { | |||||
| it->second(); | |||||
| } | |||||
| receive_callbacks_mutex_.lock(); | |||||
| receive_cond_.notify_all(); | |||||
| receive_callbacks_.erase(it); | |||||
| } | |||||
| receive_callbacks_mutex_.unlock(); | |||||
| } | |||||
| uint64_t AbstractNode::NextExpectedRankRequestId(const uint32_t &rank_id) { | |||||
| std::lock_guard<std::mutex> lock(rank_request_ids_mutex); | |||||
| uint64_t rank_request_id = 1; | |||||
| if (expected_rank_request_ids_.count(rank_id)) { | |||||
| rank_request_id = ++expected_rank_request_ids_[rank_id]; | |||||
| expected_rank_request_ids_[rank_id] = rank_request_id; | |||||
| } else { | |||||
| expected_rank_request_ids_[rank_id] = rank_request_id; | |||||
| } | |||||
| return rank_request_id; | |||||
| } | |||||
| uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) { | |||||
| std::lock_guard<std::mutex> lock(rank_request_ids_mutex); | |||||
| uint64_t rank_request_id = 1; | |||||
| if (actual_rank_request_ids_.count(rank_id)) { | |||||
| rank_request_id = ++actual_rank_request_ids_[rank_id]; | |||||
| actual_rank_request_ids_[rank_id] = rank_request_id; | |||||
| } else { | |||||
| actual_rank_request_ids_[rank_id] = rank_request_id; | |||||
| } | |||||
| return rank_request_id; | |||||
| } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,21 +34,26 @@ class AbstractNode : public Node { | |||||
| AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} | AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} | ||||
| ~AbstractNode() override = default; | ~AbstractNode() override = default; | ||||
| bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| bool Broadcast(const enum NodeRole &node_role, const std::string &message, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| void set_event_callback(const OnNodeEventMessage &on_node_event_message); | void set_event_callback(const OnNodeEventMessage &on_node_event_message); | ||||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||||
| const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||||
| CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | |||||
| const std::vector<std::string> &data, std::vector<CommMessage> *comm_message_resp, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, | |||||
| std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, | |||||
| CommMessage *output); | |||||
| bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||||
| protected: | protected: | ||||
| void Register(const std::shared_ptr<TcpClient> &client); | void Register(const std::shared_ptr<TcpClient> &client); | ||||
| void ProcessRegisterResp(const CommMessage &message); | void ProcessRegisterResp(const CommMessage &message); | ||||
| @@ -63,34 +68,51 @@ class AbstractNode : public Node { | |||||
| const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); | const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); | ||||
| bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, | ||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||||
| uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); | |||||
| void ProcessSendDataResp(const CommMessage &message); | void ProcessSendDataResp(const CommMessage &message); | ||||
| void RunMessageCallback(const uint64_t &request_id); | void RunMessageCallback(const uint64_t &request_id); | ||||
| void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); | |||||
| void set_message_callback(const uint64_t &request_id, const MessageCallback &callback); | |||||
| void NotifyMessageArrival(const CommMessage &message); | void NotifyMessageArrival(const CommMessage &message); | ||||
| void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback); | |||||
| void RunReceiveCallback(const CommMessage &message); | |||||
| uint64_t NextExpectedRankRequestId(const uint32_t &rank_id); | |||||
| uint64_t NextActualRankRequestId(const uint32_t &rank_id); | |||||
| std::unique_ptr<std::thread> heart_beat_thread_; | std::unique_ptr<std::thread> heart_beat_thread_; | ||||
| std::unique_ptr<std::thread> client_to_scheduler_thread_; | std::unique_ptr<std::thread> client_to_scheduler_thread_; | ||||
| std::shared_ptr<TcpClient> client_to_scheduler_; | std::shared_ptr<TcpClient> client_to_scheduler_; | ||||
| OnNodeEventMessage on_node_event_message_; | OnNodeEventMessage on_node_event_message_; | ||||
| // the map's key is: <node_role,rank_id>, the map's value is: <ip, port> | |||||
| // the key is: <node_role,rank_id>, the value is: <ip, port> | |||||
| std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; | std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; | ||||
| std::mutex client_mutex_; | std::mutex client_mutex_; | ||||
| // the map's key is: rank_id | // the map's key is: rank_id | ||||
| std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_; | std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_; | ||||
| // the map's key is: request_id, the map's value is: <expected responses, actual responses> | |||||
| // the key is: request_id, the value is: <expected responses, actual responses> | |||||
| std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; | std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; | ||||
| std::mutex message_tracker_mutex_; | std::mutex message_tracker_mutex_; | ||||
| std::condition_variable message_tracker_cond_; | std::condition_variable message_tracker_cond_; | ||||
| // the map's key is: request_id, the map's value is:<rank_id, CommMessage> | |||||
| // the key is: request_id, the value is:<rank_id, CommMessage> | |||||
| std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_; | std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_; | ||||
| std::mutex receive_messages_mutex_; | std::mutex receive_messages_mutex_; | ||||
| // the map's key is: request_id | |||||
| // the key is: request_id | |||||
| std::unordered_map<uint64_t, MessageCallback> message_callbacks_; | std::unordered_map<uint64_t, MessageCallback> message_callbacks_; | ||||
| std::mutex message_callbacks_mutex_; | std::mutex message_callbacks_mutex_; | ||||
| // the key is <rank_id, rank_request_id> | |||||
| std::map<std::pair<uint32_t, uint64_t>, CommMessage> received_data_; | |||||
| std::mutex receive_callbacks_mutex_; | |||||
| // the key is <rank_id, rank_request_id> | |||||
| std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_; | |||||
| std::condition_variable receive_cond_; | |||||
| // the key is rank_id, the value is rank_id's expected request_id | |||||
| std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_; | |||||
| // the key is rank_id, the value is rank_id's actual request_id | |||||
| std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_; | |||||
| std::mutex rank_request_ids_mutex; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -26,6 +26,7 @@ enum NodeCommand { | |||||
| SEND_DATA = 3; | SEND_DATA = 3; | ||||
| FETCH_SERVER = 4; | FETCH_SERVER = 4; | ||||
| FINISH = 5; | FINISH = 5; | ||||
| COLLECTIVE_SEND_DATA = 6; | |||||
| } | } | ||||
| enum NodeRole { | enum NodeRole { | ||||
| @@ -19,19 +19,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| SchedulerNode::~SchedulerNode() { | SchedulerNode::~SchedulerNode() { | ||||
| MS_LOG(INFO) << "Stop scheduler node!"; | MS_LOG(INFO) << "Stop scheduler node!"; | ||||
| if (!is_already_stopped_) { | |||||
| is_already_stopped_ = true; | |||||
| server_->Stop(); | |||||
| if (scheduler_thread_->joinable()) { | |||||
| scheduler_thread_->join(); | |||||
| } | |||||
| if (update_state_thread_->joinable()) { | |||||
| update_state_thread_->join(); | |||||
| } | |||||
| is_ready_ = true; | |||||
| } | |||||
| Stop(); | |||||
| } | } | ||||
| bool SchedulerNode::Start(const uint32_t &timeout) { | bool SchedulerNode::Start(const uint32_t &timeout) { | ||||
| @@ -114,7 +105,6 @@ void SchedulerNode::CreateTcpServer() { | |||||
| MS_LOG(INFO) << "The scheduler node start a tcp server!"; | MS_LOG(INFO) << "The scheduler node start a tcp server!"; | ||||
| server_->Start(); | server_->Start(); | ||||
| }); | }); | ||||
| scheduler_thread_->detach(); | |||||
| } | } | ||||
| void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | ||||
| @@ -186,20 +176,15 @@ void SchedulerNode::StartUpdateClusterStateTimer() { | |||||
| } | } | ||||
| } | } | ||||
| }); | }); | ||||
| update_state_thread_->detach(); | |||||
| } | } | ||||
| bool SchedulerNode::Stop() { | bool SchedulerNode::Stop() { | ||||
| MS_LOG(INFO) << "Stop scheduler node!"; | MS_LOG(INFO) << "Stop scheduler node!"; | ||||
| if (!is_already_stopped_) { | if (!is_already_stopped_) { | ||||
| is_already_stopped_ = true; | is_already_stopped_ = true; | ||||
| update_state_thread_->join(); | |||||
| server_->Stop(); | server_->Stop(); | ||||
| if (scheduler_thread_->joinable()) { | |||||
| scheduler_thread_->join(); | |||||
| } | |||||
| if (update_state_thread_->joinable()) { | |||||
| update_state_thread_->join(); | |||||
| } | |||||
| scheduler_thread_->join(); | |||||
| is_ready_ = true; | is_ready_ = true; | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -38,6 +38,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| class SchedulerNode : public Node { | class SchedulerNode : public Node { | ||||
| public: | public: | ||||
| SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} | SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} | ||||
| @@ -20,18 +20,7 @@ namespace ps { | |||||
| namespace core { | namespace core { | ||||
| ServerNode::~ServerNode() { | ServerNode::~ServerNode() { | ||||
| MS_LOG(INFO) << "Stop server node!"; | MS_LOG(INFO) << "Stop server node!"; | ||||
| if (!is_already_stopped_.load()) { | |||||
| server_->Stop(); | |||||
| client_to_scheduler_->Stop(); | |||||
| client_to_scheduler_->StopEventBase(); | |||||
| if (server_thread_->joinable()) { | |||||
| server_thread_->join(); | |||||
| } | |||||
| if (client_to_scheduler_thread_->joinable()) { | |||||
| client_to_scheduler_thread_->join(); | |||||
| } | |||||
| is_already_stopped_ = true; | |||||
| } | |||||
| Stop(); | |||||
| } | } | ||||
| bool ServerNode::Start(const uint32_t &timeout) { | bool ServerNode::Start(const uint32_t &timeout) { | ||||
| @@ -78,6 +67,10 @@ void ServerNode::CreateTcpServer() { | |||||
| case NodeCommand::SEND_DATA: | case NodeCommand::SEND_DATA: | ||||
| ProcessSendData(server, conn, message); | ProcessSendData(server, conn, message); | ||||
| break; | break; | ||||
| case NodeCommand::COLLECTIVE_SEND_DATA: | |||||
| ProcessCollectiveSendData(server, conn, message); | |||||
| RunReceiveCallback(message); | |||||
| break; | |||||
| default: | default: | ||||
| MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; | ||||
| } | } | ||||
| @@ -87,7 +80,6 @@ void ServerNode::CreateTcpServer() { | |||||
| MS_LOG(INFO) << "The server node start a tcp server!"; | MS_LOG(INFO) << "The server node start a tcp server!"; | ||||
| server_->Start(); | server_->Start(); | ||||
| }); | }); | ||||
| server_thread_->detach(); | |||||
| } | } | ||||
| void ServerNode::Initialize() { | void ServerNode::Initialize() { | ||||
| @@ -106,27 +98,31 @@ void ServerNode::Initialize() { | |||||
| } | } | ||||
| void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | ||||
| if (request_handler_) { | |||||
| request_handler_(server, conn, message.pb_meta(), message.data()); | |||||
| } | |||||
| request_handler_(server, conn, message.pb_meta(), message.data()); | |||||
| } | |||||
| void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, | |||||
| const CommMessage &message) { | |||||
| CommMessage comm_message; | |||||
| *comm_message.mutable_pb_meta() = {message.pb_meta()}; | |||||
| const_cast<TcpServer &>(server).SendMessage(conn, comm_message); | |||||
| } | } | ||||
| bool ServerNode::Stop() { | bool ServerNode::Stop() { | ||||
| MS_LOG(INFO) << "Stop server node!"; | MS_LOG(INFO) << "Stop server node!"; | ||||
| if (!is_already_stopped_.load()) { | if (!is_already_stopped_.load()) { | ||||
| server_->Stop(); | |||||
| is_already_stopped_ = true; | |||||
| is_finish_ = true; | |||||
| heart_beat_thread_->join(); | |||||
| client_to_scheduler_->Stop(); | client_to_scheduler_->Stop(); | ||||
| client_to_scheduler_->StopEventBase(); | |||||
| if (server_thread_->joinable()) { | |||||
| server_thread_->join(); | |||||
| if (!connected_nodes_.empty()) { | |||||
| for (auto &connected_node : connected_nodes_) { | |||||
| connected_node.second->Stop(); | |||||
| } | |||||
| } | } | ||||
| if (client_to_scheduler_thread_->joinable()) { | |||||
| client_to_scheduler_thread_->join(); | |||||
| } | |||||
| if (heart_beat_thread_->joinable()) { | |||||
| heart_beat_thread_->join(); | |||||
| } | |||||
| is_already_stopped_ = true; | |||||
| client_to_scheduler_thread_->join(); | |||||
| server_->Stop(); | |||||
| server_thread_->join(); | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -44,8 +44,8 @@ class ServerNode : public AbstractNode { | |||||
| bool Stop() override; | bool Stop() override; | ||||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | ||||
| using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, | |||||
| const MessageMeta message_meta, const std::string &message)>; | |||||
| using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, const MessageMeta meta, | |||||
| const std::string &message)>; | |||||
| void set_handler(const RequestHandler &handler); | void set_handler(const RequestHandler &handler); | ||||
| void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, | void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, | ||||
| @@ -55,6 +55,7 @@ class ServerNode : public AbstractNode { | |||||
| void CreateTcpServer(); | void CreateTcpServer(); | ||||
| void Initialize(); | void Initialize(); | ||||
| void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | ||||
| void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); | |||||
| std::shared_ptr<TcpServer> server_; | std::shared_ptr<TcpServer> server_; | ||||
| std::unique_ptr<std::thread> server_thread_; | std::unique_ptr<std::thread> server_thread_; | ||||
| @@ -51,7 +51,20 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port) | |||||
| }); | }); | ||||
| } | } | ||||
| TcpClient::~TcpClient() { Stop(); } | |||||
| TcpClient::~TcpClient() { | |||||
| if (buffer_event_) { | |||||
| bufferevent_free(buffer_event_); | |||||
| buffer_event_ = nullptr; | |||||
| } | |||||
| if (event_timeout_) { | |||||
| event_free(event_timeout_); | |||||
| event_timeout_ = nullptr; | |||||
| } | |||||
| if (event_base_) { | |||||
| event_base_free(event_base_); | |||||
| event_base_ = nullptr; | |||||
| } | |||||
| } | |||||
| std::string TcpClient::GetServerAddress() const { return server_address_; } | std::string TcpClient::GetServerAddress() const { return server_address_; } | ||||
| @@ -69,9 +82,9 @@ bool TcpClient::WaitConnected(const uint32_t &connected_timeout) { | |||||
| void TcpClient::Init() { | void TcpClient::Init() { | ||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | std::lock_guard<std::mutex> lock(connection_mutex_); | ||||
| if (buffer_event_) { | if (buffer_event_) { | ||||
| return; | |||||
| bufferevent_free(buffer_event_); | |||||
| buffer_event_ = nullptr; | |||||
| } | } | ||||
| is_stop_ = false; | |||||
| if (!CommUtil::CheckIp(server_address_)) { | if (!CommUtil::CheckIp(server_address_)) { | ||||
| MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; | MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; | ||||
| } | } | ||||
| @@ -82,8 +95,9 @@ void TcpClient::Init() { | |||||
| } | } | ||||
| if (event_base_ == nullptr) { | if (event_base_ == nullptr) { | ||||
| event_base_ = event_base_new(); | event_base_ = event_base_new(); | ||||
| MS_EXCEPTION_IF_NULL(event_base_); | |||||
| is_stop_ = false; | |||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(event_base_); | |||||
| sockaddr_in sin{}; | sockaddr_in sin{}; | ||||
| if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { | if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { | ||||
| @@ -127,26 +141,18 @@ void TcpClient::StartWithDelay(int seconds) { | |||||
| void TcpClient::Stop() { | void TcpClient::Stop() { | ||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | std::lock_guard<std::mutex> lock(connection_mutex_); | ||||
| MS_LOG(INFO) << "Stop tcp client event buffer!"; | |||||
| if (!is_stop_.load()) { | |||||
| if (buffer_event_) { | |||||
| bufferevent_free(buffer_event_); | |||||
| buffer_event_ = nullptr; | |||||
| } | |||||
| if (event_timeout_) { | |||||
| event_free(event_timeout_); | |||||
| event_timeout_ = nullptr; | |||||
| } | |||||
| MS_LOG(INFO) << "Stop tcp client!"; | |||||
| if (event_base_got_break(event_base_)) { | |||||
| MS_LOG(DEBUG) << "The event base has stopped!"; | |||||
| is_stop_ = true; | is_stop_ = true; | ||||
| return; | |||||
| } | } | ||||
| } | |||||
| void TcpClient::StopEventBase() { | |||||
| MS_LOG(INFO) << "Stop tcp client event base!"; | |||||
| int ret = event_base_loopbreak(event_base_); | |||||
| if (ret != 0) { | |||||
| MS_LOG(ERROR) << "Event base loop break failed!"; | |||||
| if (!is_stop_.load()) { | |||||
| is_stop_ = true; | |||||
| int ret = event_base_loopbreak(event_base_); | |||||
| if (ret != 0) { | |||||
| MS_LOG(ERROR) << "Event base loop break failed!"; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -280,6 +286,7 @@ void TcpClient::StartTimer(const uint32_t &time) { | |||||
| void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } | void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } | ||||
| const event_base &TcpClient::eventbase() { return *event_base_; } | const event_base &TcpClient::eventbase() { return *event_base_; } | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -58,7 +58,6 @@ class TcpClient { | |||||
| void Init(); | void Init(); | ||||
| void StartWithDelay(int seconds); | void StartWithDelay(int seconds); | ||||
| void Stop(); | void Stop(); | ||||
| static void StopEventBase(); | |||||
| void Start(); | void Start(); | ||||
| void StartWithNoBlock(); | void StartWithNoBlock(); | ||||
| void SetMessageCallback(const OnMessage &cb); | void SetMessageCallback(const OnMessage &cb); | ||||
| @@ -97,6 +96,7 @@ class TcpClient { | |||||
| std::atomic<bool> is_stop_; | std::atomic<bool> is_stop_; | ||||
| std::atomic<bool> is_connected_; | std::atomic<bool> is_connected_; | ||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,6 +32,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| void TcpConnection::InitConnection() { | void TcpConnection::InitConnection() { | ||||
| tcp_message_handler_.SetCallback([&](const CommMessage &message) { | tcp_message_handler_.SetCallback([&](const CommMessage &message) { | ||||
| OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); | OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); | ||||
| @@ -76,7 +77,22 @@ TcpServer::TcpServer(const std::string &address, std::uint16_t port) | |||||
| server_port_(port), | server_port_(port), | ||||
| is_stop_(true) {} | is_stop_(true) {} | ||||
| TcpServer::~TcpServer() { Stop(); } | |||||
| TcpServer::~TcpServer() { | |||||
| if (signal_event_ != nullptr) { | |||||
| event_free(signal_event_); | |||||
| signal_event_ = nullptr; | |||||
| } | |||||
| if (listener_ != nullptr) { | |||||
| evconnlistener_free(listener_); | |||||
| listener_ = nullptr; | |||||
| } | |||||
| if (base_ != nullptr) { | |||||
| event_base_free(base_); | |||||
| base_ = nullptr; | |||||
| } | |||||
| } | |||||
| void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | ||||
| const OnAccepted &client_accept) { | const OnAccepted &client_accept) { | ||||
| @@ -136,7 +152,6 @@ void TcpServer::Init() { | |||||
| } | } | ||||
| void TcpServer::Start() { | void TcpServer::Start() { | ||||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||||
| MS_LOG(INFO) << "Start tcp server!"; | MS_LOG(INFO) << "Start tcp server!"; | ||||
| MS_EXCEPTION_IF_NULL(base_); | MS_EXCEPTION_IF_NULL(base_); | ||||
| int ret = event_base_dispatch(base_); | int ret = event_base_dispatch(base_); | ||||
| @@ -148,7 +163,7 @@ void TcpServer::Start() { | |||||
| } | } | ||||
| void TcpServer::StartWithNoBlock() { | void TcpServer::StartWithNoBlock() { | ||||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| MS_LOG(INFO) << "Start tcp server with no block!"; | MS_LOG(INFO) << "Start tcp server with no block!"; | ||||
| MS_EXCEPTION_IF_NULL(base_); | MS_EXCEPTION_IF_NULL(base_); | ||||
| int ret = event_base_loop(base_, EVLOOP_NONBLOCK); | int ret = event_base_loop(base_, EVLOOP_NONBLOCK); | ||||
| @@ -187,33 +202,25 @@ void TcpServer::StartTimer(const uint32_t &time) { | |||||
| } | } | ||||
| void TcpServer::Stop() { | void TcpServer::Stop() { | ||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| MS_LOG(INFO) << "Stop tcp server!"; | MS_LOG(INFO) << "Stop tcp server!"; | ||||
| if (event_base_got_break(base_)) { | |||||
| MS_LOG(DEBUG) << "The event base has stopped!"; | |||||
| is_stop_ = true; | |||||
| return; | |||||
| } | |||||
| if (!is_stop_.load()) { | if (!is_stop_.load()) { | ||||
| is_stop_ = true; | |||||
| int ret = event_base_loopbreak(base_); | int ret = event_base_loopbreak(base_); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "event base loop break failed!"; | |||||
| } | |||||
| if (signal_event_ != nullptr) { | |||||
| event_free(signal_event_); | |||||
| signal_event_ = nullptr; | |||||
| } | |||||
| if (listener_ != nullptr) { | |||||
| evconnlistener_free(listener_); | |||||
| listener_ = nullptr; | |||||
| MS_LOG(ERROR) << "Event base loop break failed!"; | |||||
| } | } | ||||
| if (base_ != nullptr) { | |||||
| event_base_free(base_); | |||||
| base_ = nullptr; | |||||
| } | |||||
| is_stop_ = true; | |||||
| } | } | ||||
| } | } | ||||
| void TcpServer::SendToAllClients(const char *data, size_t len) { | void TcpServer::SendToAllClients(const char *data, size_t len) { | ||||
| MS_EXCEPTION_IF_NULL(data); | MS_EXCEPTION_IF_NULL(data); | ||||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| for (auto it = connections_.begin(); it != connections_.end(); ++it) { | for (auto it = connections_.begin(); it != connections_.end(); ++it) { | ||||
| it->second->SendMessage(data, len); | it->second->SendMessage(data, len); | ||||
| } | } | ||||
| @@ -221,12 +228,12 @@ void TcpServer::SendToAllClients(const char *data, size_t len) { | |||||
| void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) { | void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) { | ||||
| MS_EXCEPTION_IF_NULL(connection); | MS_EXCEPTION_IF_NULL(connection); | ||||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| connections_.insert(std::make_pair(fd, connection)); | connections_.insert(std::make_pair(fd, connection)); | ||||
| } | } | ||||
| void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | ||||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second); | TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second); | ||||
| delete connection; | delete connection; | ||||
| connections_.erase(fd); | connections_.erase(fd); | ||||
| @@ -352,7 +359,7 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) { | |||||
| void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } | void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } | ||||
| void TcpServer::SendMessage(const CommMessage &message) { | void TcpServer::SendMessage(const CommMessage &message) { | ||||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| for (auto it = connections_.begin(); it != connections_.end(); ++it) { | for (auto it = connections_.begin(); it != connections_.end(); ++it) { | ||||
| SendMessage(*it->second, message); | SendMessage(*it->second, message); | ||||
| @@ -368,6 +375,7 @@ int TcpServer::ConnectionNum() const { return connections_.size(); } | |||||
| const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } | const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } | ||||
| void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -121,7 +121,7 @@ class TcpServer { | |||||
| OnConnected client_connection_; | OnConnected client_connection_; | ||||
| OnDisconnected client_disconnection_; | OnDisconnected client_disconnection_; | ||||
| OnAccepted client_accept_; | OnAccepted client_accept_; | ||||
| std::recursive_mutex connection_mutex_; | |||||
| std::mutex connection_mutex_; | |||||
| OnServerReceiveMessage message_callback_; | OnServerReceiveMessage message_callback_; | ||||
| OnTimerOnce on_timer_once_callback_; | OnTimerOnce on_timer_once_callback_; | ||||
| OnTimer on_timer_callback_; | OnTimer on_timer_callback_; | ||||
| @@ -21,24 +21,7 @@ namespace ps { | |||||
| namespace core { | namespace core { | ||||
| WorkerNode::~WorkerNode() { | WorkerNode::~WorkerNode() { | ||||
| MS_LOG(INFO) << "Stop worker node!"; | MS_LOG(INFO) << "Stop worker node!"; | ||||
| if (!is_already_stopped_.load()) { | |||||
| is_ready_ = true; | |||||
| is_timeout_ = true; | |||||
| client_to_scheduler_->Stop(); | |||||
| if (!connected_nodes_.empty()) { | |||||
| for (auto &connected_node : connected_nodes_) { | |||||
| connected_node.second->Stop(); | |||||
| } | |||||
| } | |||||
| client_to_scheduler_->StopEventBase(); | |||||
| if (client_to_scheduler_thread_->joinable()) { | |||||
| client_to_scheduler_thread_->join(); | |||||
| } | |||||
| if (heart_beat_thread_->joinable()) { | |||||
| heart_beat_thread_->join(); | |||||
| } | |||||
| is_already_stopped_ = true; | |||||
| } | |||||
| Stop(); | |||||
| } | } | ||||
| bool WorkerNode::Start(const uint32_t &timeout) { | bool WorkerNode::Start(const uint32_t &timeout) { | ||||
| MS_LOG(INFO) << "Starting worker node!"; | MS_LOG(INFO) << "Starting worker node!"; | ||||
| @@ -78,19 +61,15 @@ bool WorkerNode::Stop() { | |||||
| if (!is_already_stopped_.load()) { | if (!is_already_stopped_.load()) { | ||||
| is_ready_ = true; | is_ready_ = true; | ||||
| is_timeout_ = true; | is_timeout_ = true; | ||||
| is_finish_ = true; | |||||
| heart_beat_thread_->join(); | |||||
| client_to_scheduler_->Stop(); | client_to_scheduler_->Stop(); | ||||
| if (!connected_nodes_.empty()) { | if (!connected_nodes_.empty()) { | ||||
| for (auto &connected_node : connected_nodes_) { | for (auto &connected_node : connected_nodes_) { | ||||
| connected_node.second->Stop(); | connected_node.second->Stop(); | ||||
| } | } | ||||
| } | } | ||||
| client_to_scheduler_->StopEventBase(); | |||||
| if (client_to_scheduler_thread_->joinable()) { | |||||
| client_to_scheduler_thread_->join(); | |||||
| } | |||||
| if (heart_beat_thread_->joinable()) { | |||||
| heart_beat_thread_->join(); | |||||
| } | |||||
| client_to_scheduler_thread_->join(); | |||||
| is_already_stopped_ = true; | is_already_stopped_ = true; | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common_test.h" | |||||
| #define protected public | |||||
| #include "ps/core/worker_node.h" | |||||
| #undef protected | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| class TestAbstractNode : public UT::Common { | |||||
| public: | |||||
| TestAbstractNode() = default; | |||||
| virtual ~TestAbstractNode() = default; | |||||
| void SetUp() override {} | |||||
| void TearDown() override {} | |||||
| }; | |||||
| TEST_F(TestAbstractNode, NextExpectedRankRequestId) { | |||||
| WorkerNode workerNode; | |||||
| ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(0)); | |||||
| ASSERT_EQ(2, workerNode.NextExpectedRankRequestId(0)); | |||||
| ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(1)); | |||||
| } | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||