From: @anancds (tags/v1.2.0-rc1)
@@ -53,13 +53,20 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
  MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
}

bool AbstractNode::BroadcastToServers(const std::string &message, const uint32_t &timeout) {
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) {
  if (node_role != NodeRole::SERVER) {
    MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
  }
  uint64_t request_id = ++next_request_id_;
  message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
  for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
    MessageMeta message_meta;
    message_meta.set_cmd(NodeCommand::SEND_DATA);
    message_meta.set_request_id(request_id);
    message_meta.set_rank_id(node_info_.rank_id_);
    message_meta.set_role(node_info_.node_role_);
    CommMessage comm_message;
    *comm_message.mutable_pb_meta() = {message_meta};
@@ -82,12 +89,14 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
  MessageMeta message_meta;
  message_meta.set_cmd(NodeCommand::SEND_DATA);
  message_meta.set_rank_id(node_info_.rank_id_);
  message_meta.set_role(node_info_.node_role_);
  CommMessage comm_message;
  *comm_message.mutable_pb_meta() = {message_meta};
  comm_message.set_data(message);
  auto client = GetOrCreateTcpClient(rank_id);
  return SendMessageSync(client, comm_message);
  return SendMessageSync(client, comm_message, timeout);
}

bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
@@ -106,6 +115,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
    MessageMeta message_meta;
    message_meta.set_cmd(NodeCommand::SEND_DATA);
    message_meta.set_request_id(request_id);
    message_meta.set_rank_id(node_info_.rank_id_);
    message_meta.set_role(node_info_.node_role_);
    CommMessage comm_message;
    *comm_message.mutable_pb_meta() = {message_meta};
@@ -118,8 +129,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
                        CommMessage *comm_message_resp, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(comm_message_resp);
                        CommMessage *output, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(output);
  if (!CommUtil::ValidateRankId(node_role, rank_id)) {
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
  }
@@ -129,7 +140,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
  set_message_callback(request_id, [&]() {
    receive_messages_mutex_.lock();
    auto res = receive_messages_[request_id];
    *comm_message_resp = res[rank_id];
    *output = res[rank_id];
    receive_messages_.erase(request_id);
    receive_messages_mutex_.unlock();
  });
@@ -149,9 +160,9 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
}

bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
                        const std::vector<std::string> &data, std::vector<CommMessage> *comm_message_resp,
                        const std::vector<std::string> &data, std::vector<CommMessage> *output,
                        const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(comm_message_resp);
  MS_EXCEPTION_IF_NULL(output);
  uint64_t request_id = ++next_request_id_;
  message_tracker_[request_id] = std::make_pair(data.size(), 0);
@@ -165,7 +176,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
    receive_messages_mutex_.lock();
    auto res = receive_messages_[request_id];
    for (size_t it = 0; it < len; ++it) {
      (*comm_message_resp).push_back(res[rank_ids.at(it)]);
      (*output).push_back(res[rank_ids.at(it)]);
    }
    receive_messages_.erase(request_id);
    receive_messages_mutex_.unlock();
@@ -179,6 +190,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
    MessageMeta message_meta;
    message_meta.set_cmd(NodeCommand::SEND_DATA);
    message_meta.set_request_id(request_id);
    message_meta.set_rank_id(node_info_.rank_id_);
    message_meta.set_role(node_info_.node_role_);
    CommMessage comm_message;
    *comm_message.mutable_pb_meta() = {message_meta};
@@ -200,6 +213,58 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
  return res;
}

uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
                                           const std::string &message, const uint32_t &timeout) {
  if (!CommUtil::ValidateRankId(node_role, rank_id)) {
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
  }
  MessageMeta message_meta;
  message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
  message_meta.set_rank_id(node_info_.rank_id_);
  message_meta.set_role(node_info_.node_role_);
  CommMessage comm_message;
  *comm_message.mutable_pb_meta() = {message_meta};
  comm_message.set_data(message);
  auto client = GetOrCreateTcpClient(rank_id);
  return SendMessageAsync(client, comm_message);
}
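A minimal caller-side sketch of the new async send path (names such as node, peer_rank, and payload are illustrative, not from this diff; it also assumes the client's read callback ends by calling NotifyMessageArrival for the echoed COLLECTIVE_SEND_DATA response, just as it does for SEND_DATA):

  // Fire the collective send, then block on the returned request id until the
  // server's acknowledgement arrives or the timeout elapses.
  uint64_t request_id = node.CollectiveSendAsync(NodeRole::SERVER, peer_rank, payload);
  if (!node.Wait(request_id, /*timeout=*/30)) {
    MS_LOG(ERROR) << "Collective send to rank " << peer_rank << " timed out!";
  }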
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
                                                                   const uint32_t &rank_id, CommMessage *output) {
  if (!CommUtil::ValidateRankId(node_role, rank_id)) {
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
  }
  uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
  if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
    *output = received_data_[std::make_pair(rank_id, rank_request_id)];
    received_data_.erase(std::make_pair(rank_id, rank_request_id));
  } else {
    set_receive_callback(rank_id, rank_request_id, [=]() {
      receive_callbacks_mutex_.lock();
      // Look up by the same rank request id the callback was registered for
      // (not a hard-coded id), so receives after the first one also work.
      *output = received_data_[std::make_pair(rank_id, rank_request_id)];
      received_data_.erase(std::make_pair(rank_id, rank_request_id));
      receive_callbacks_mutex_.unlock();
    });
  }
  return std::make_pair(rank_id, rank_request_id);
}
bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) {
  std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
  bool res = receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
    if (actual_rank_request_ids_.count(request_id.first) &&
        (actual_rank_request_ids_[request_id.first] >= request_id.second)) {
      return true;
    } else {
      return false;
    }
  });
  return res;
}
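The matching receive side, sketched under the same assumptions (node and peer_rank are again illustrative):

  // Register interest in the next message from peer_rank, then block until the
  // matching rank request id has actually arrived.
  CommMessage received;
  auto recv_id = node.CollectiveReceiveAsync(NodeRole::SERVER, peer_rank, &received);
  if (!node.CollectiveWait(recv_id, /*timeout=*/30)) {
    MS_LOG(ERROR) << "Collective receive from rank " << peer_rank << " timed out!";
  }
  // `received` is filled either immediately (the message was already buffered in
  // received_data_) or later by the callback installed via set_receive_callback,
  // so it must stay alive until CollectiveWait returns.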
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
  MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
               << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
@@ -210,7 +275,6 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
      std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
    }
  });
  heart_beat_thread_->detach();
}
void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
@@ -334,11 +398,9 @@ bool AbstractNode::InitClientToScheduler() {
    MS_LOG(INFO) << "The worker node start a tcp client!";
    client_to_scheduler_->Start();
  });
  client_to_scheduler_thread_->detach();

  client_to_scheduler_->set_disconnected_callback([&]() {
    std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval()));
    client_to_scheduler_->Stop();
    client_to_scheduler_->Init();
  });
  return client_to_scheduler_->WaitConnected();
@@ -361,6 +423,9 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &
      case NodeCommand::SEND_DATA:
        ProcessSendDataResp(message);
        RunMessageCallback(message.pb_meta().request_id());
        break;
      case NodeCommand::COLLECTIVE_SEND_DATA:
        MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
        break;
      default:
        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
    }
@@ -381,10 +446,12 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con
  return Wait(request_id, timeout);
}

void AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
  uint64_t request_id = ++next_request_id_;
  message_tracker_[request_id] = std::make_pair(1, 0);
  const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
  client->SendMessage(message);
  return request_id;
}
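Two things are worth noting in SendMessageAsync: returning the request id (the old version returned void) is what lets CollectiveSendAsync hand its caller a token to Wait on, and the const_cast means the caller's nominally const CommMessage has its meta mutated in place to stamp that request id.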
void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
@@ -422,12 +489,12 @@ void AbstractNode::RunMessageCallback(const uint64_t &request_id) {
  message_callbacks_mutex_.unlock();
}

void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) {
  if (!message_callback) {
void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &callback) {
  if (!callback) {
    return;
  }
  std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
  message_callbacks_[request_id] = message_callback;
  message_callbacks_[request_id] = callback;
}

void AbstractNode::NotifyMessageArrival(const CommMessage &message) {
@@ -438,6 +505,61 @@ void AbstractNode::NotifyMessageArrival(const CommMessage &message) {
  message_tracker_[request_id].second++;
  message_tracker_cond_.notify_all();
}
void AbstractNode::set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id,
                                        const MessageCallback &callback) {
  if (!callback) {
    return;
  }
  std::lock_guard<std::mutex> lock(receive_callbacks_mutex_);
  receive_callbacks_[std::make_pair(rank_id, request_id)] = callback;
}
void AbstractNode::RunReceiveCallback(const CommMessage &message) {
  receive_callbacks_mutex_.lock();
  uint32_t rank_id = message.pb_meta().rank_id();
  // When a collective message arrives, generate the actual rank request id and compare it with the
  // expected rank request id; if they are equal, run the registered callback.
  uint64_t rank_request_id = NextActualRankRequestId(rank_id);
  received_data_[std::make_pair(rank_id, rank_request_id)] = message;
  auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id));
  if (it != receive_callbacks_.end()) {
    receive_callbacks_mutex_.unlock();
    if (it->second) {
      it->second();
    }
    receive_callbacks_mutex_.lock();
    receive_cond_.notify_all();
    receive_callbacks_.erase(it);
  }
  receive_callbacks_mutex_.unlock();
}
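RunReceiveCallback deliberately releases receive_callbacks_mutex_ before invoking the callback: the callback registered by CollectiveReceiveAsync locks the same non-recursive mutex itself, so invoking it while still holding the lock would self-deadlock.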
uint64_t AbstractNode::NextExpectedRankRequestId(const uint32_t &rank_id) {
  std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
  uint64_t rank_request_id = 1;
  if (expected_rank_request_ids_.count(rank_id)) {
    rank_request_id = ++expected_rank_request_ids_[rank_id];
  } else {
    expected_rank_request_ids_[rank_id] = rank_request_id;
  }
  return rank_request_id;
}
uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
  std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
  uint64_t rank_request_id = 1;
  if (actual_rank_request_ids_.count(rank_id)) {
    rank_request_id = ++actual_rank_request_ids_[rank_id];
  } else {
    actual_rank_request_ids_[rank_id] = rank_request_id;
  }
  return rank_request_id;
}
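How the two counters pair up for a single peer, traced as an illustrative walk-through (not code from this change):

  // receiver: CollectiveReceiveAsync -> NextExpectedRankRequestId(0) == 1
  // network : the first message from rank 0 arrives
  // handler : RunReceiveCallback -> NextActualRankRequestId(0) == 1; the data is
  //           stored and the callback registered for (0, 1) is found and run
  // waiter  : CollectiveWait sees actual_rank_request_ids_[0] >= 1 and returns
  // A second receive repeats the handshake with both counters at 2, which is how
  // messages from one rank are consumed strictly in arrival order.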
}  // namespace core
}  // namespace ps
}  // namespace mindspore
@@ -34,21 +34,26 @@ class AbstractNode : public Node {
  AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
  ~AbstractNode() override = default;

  bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds);
  bool Broadcast(const enum NodeRole &node_role, const std::string &message,
                 const uint32_t &timeout = kCommTimeoutInSeconds);
  void set_event_callback(const OnNodeEventMessage &on_node_event_message);
  virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
                    const uint32_t &timeout = kCommTimeoutInSeconds);
  virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
                    const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
  virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
                    CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
  virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
                    const std::vector<std::string> &data, std::vector<CommMessage> *comm_message_resp,
                    const uint32_t &timeout = kCommTimeoutInSeconds);
  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
            const uint32_t &timeout = kCommTimeoutInSeconds);
  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
            const uint32_t &timeout = kCommTimeoutInSeconds);
  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output,
            const uint32_t &timeout = kCommTimeoutInSeconds);
  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
            std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
  bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);

  uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
                               const uint32_t &timeout = kCommTimeoutInSeconds);
  std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
                                                       CommMessage *output);
  bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);

 protected:
  void Register(const std::shared_ptr<TcpClient> &client);
  void ProcessRegisterResp(const CommMessage &message);
@@ -63,34 +68,51 @@ class AbstractNode : public Node {
  const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
  bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
                       const uint32_t &timeout = kCommTimeoutInSeconds);
  void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
  uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
  void ProcessSendDataResp(const CommMessage &message);
  void RunMessageCallback(const uint64_t &request_id);
  void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
  void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
  void NotifyMessageArrival(const CommMessage &message);
  void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback);
  void RunReceiveCallback(const CommMessage &message);
  uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
  uint64_t NextActualRankRequestId(const uint32_t &rank_id);

  std::unique_ptr<std::thread> heart_beat_thread_;
  std::unique_ptr<std::thread> client_to_scheduler_thread_;
  std::shared_ptr<TcpClient> client_to_scheduler_;
  OnNodeEventMessage on_node_event_message_;
  // the map's key is: <node_role,rank_id>, the map's value is: <ip, port>
  // the key is: <node_role,rank_id>, the value is: <ip, port>
  std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
  std::mutex client_mutex_;
  // the map's key is: rank_id
  std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
  // the map's key is: request_id, the map's value is: <expected responses, actual responses>
  // the key is: request_id, the value is: <expected responses, actual responses>
  std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
  std::mutex message_tracker_mutex_;
  std::condition_variable message_tracker_cond_;
  // the map's key is: request_id, the map's value is: <rank_id, CommMessage>
  // the key is: request_id, the value is: <rank_id, CommMessage>
  std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
  std::mutex receive_messages_mutex_;
  // the map's key is: request_id
  // the key is: request_id
  std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
  std::mutex message_callbacks_mutex_;

  // the key is <rank_id, rank_request_id>
  std::map<std::pair<uint32_t, uint64_t>, CommMessage> received_data_;
  std::mutex receive_callbacks_mutex_;
  // the key is <rank_id, rank_request_id>
  std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_;
  std::condition_variable receive_cond_;
  // the key is rank_id, the value is rank_id's expected request_id
  std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_;
  // the key is rank_id, the value is rank_id's actual request_id
  std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
  std::mutex rank_request_ids_mutex;
};
}  // namespace core
}  // namespace ps
@@ -26,6 +26,7 @@ enum NodeCommand {
  SEND_DATA = 3;
  FETCH_SERVER = 4;
  FINISH = 5;
  COLLECTIVE_SEND_DATA = 6;
}

enum NodeRole {
@@ -19,19 +19,10 @@
namespace mindspore {
namespace ps {
namespace core {
SchedulerNode::~SchedulerNode() {
  MS_LOG(INFO) << "Stop scheduler node!";
  if (!is_already_stopped_) {
    is_already_stopped_ = true;
    server_->Stop();
    if (scheduler_thread_->joinable()) {
      scheduler_thread_->join();
    }
    if (update_state_thread_->joinable()) {
      update_state_thread_->join();
    }
    is_ready_ = true;
  }
  Stop();
}
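Collapsing the destructor's duplicated teardown into a plain Stop() call works because Stop() guards on is_already_stopped_, so destruction after an explicit Stop() becomes a harmless no-op instead of a double stop/join.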
bool SchedulerNode::Start(const uint32_t &timeout) {
@@ -114,7 +105,6 @@ void SchedulerNode::CreateTcpServer() {
    MS_LOG(INFO) << "The scheduler node start a tcp server!";
    server_->Start();
  });
  scheduler_thread_->detach();
}

void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
@@ -186,20 +176,15 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
      }
    }
  });
  update_state_thread_->detach();
}

bool SchedulerNode::Stop() {
  MS_LOG(INFO) << "Stop scheduler node!";
  if (!is_already_stopped_) {
    is_already_stopped_ = true;
    update_state_thread_->join();
    server_->Stop();
    if (scheduler_thread_->joinable()) {
      scheduler_thread_->join();
    }
    if (update_state_thread_->joinable()) {
      update_state_thread_->join();
    }
    scheduler_thread_->join();
    is_ready_ = true;
  }
  return true;
@@ -38,6 +38,7 @@
namespace mindspore {
namespace ps {
namespace core {
class SchedulerNode : public Node {
 public:
  SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
@@ -20,18 +20,7 @@ namespace ps {
namespace core {
ServerNode::~ServerNode() {
  MS_LOG(INFO) << "Stop server node!";
  if (!is_already_stopped_.load()) {
    server_->Stop();
    client_to_scheduler_->Stop();
    client_to_scheduler_->StopEventBase();
    if (server_thread_->joinable()) {
      server_thread_->join();
    }
    if (client_to_scheduler_thread_->joinable()) {
      client_to_scheduler_thread_->join();
    }
    is_already_stopped_ = true;
  }
  Stop();
}
bool ServerNode::Start(const uint32_t &timeout) {
@@ -78,6 +67,10 @@ void ServerNode::CreateTcpServer() {
      case NodeCommand::SEND_DATA:
        ProcessSendData(server, conn, message);
        break;
      case NodeCommand::COLLECTIVE_SEND_DATA:
        ProcessCollectiveSendData(server, conn, message);
        RunReceiveCallback(message);
        break;
      default:
        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
    }
@@ -87,7 +80,6 @@ void ServerNode::CreateTcpServer() {
    MS_LOG(INFO) << "The server node start a tcp server!";
    server_->Start();
  });
  server_thread_->detach();
}
void ServerNode::Initialize() {
@@ -106,27 +98,31 @@ void ServerNode::Initialize() {
}

void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
  if (request_handler_) {
    request_handler_(server, conn, message.pb_meta(), message.data());
  }
  request_handler_(server, conn, message.pb_meta(), message.data());
}

void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn,
                                           const CommMessage &message) {
  CommMessage comm_message;
  *comm_message.mutable_pb_meta() = {message.pb_meta()};
  const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
}
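ProcessCollectiveSendData echoes only the message meta back to the sender, with no payload: on the sending node that empty reply (assuming, as with SEND_DATA, the client's read callback ends in NotifyMessageArrival) is what completes the Wait bookkeeping for the COLLECTIVE_SEND_DATA request, while the payload itself is consumed locally through the RunReceiveCallback call above.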
bool ServerNode::Stop() {
  MS_LOG(INFO) << "Stop server node!";
  if (!is_already_stopped_.load()) {
    server_->Stop();
    is_already_stopped_ = true;
    is_finish_ = true;
    heart_beat_thread_->join();
    client_to_scheduler_->Stop();
    client_to_scheduler_->StopEventBase();
    if (server_thread_->joinable()) {
      server_thread_->join();
    if (!connected_nodes_.empty()) {
      for (auto &connected_node : connected_nodes_) {
        connected_node.second->Stop();
      }
    }
    if (client_to_scheduler_thread_->joinable()) {
      client_to_scheduler_thread_->join();
    }
    if (heart_beat_thread_->joinable()) {
      heart_beat_thread_->join();
    }
    is_already_stopped_ = true;
    client_to_scheduler_thread_->join();
    server_->Stop();
    server_thread_->join();
  }
  return true;
}
@@ -44,8 +44,8 @@ class ServerNode : public AbstractNode {
  bool Stop() override;
  bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;

  using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn,
                                            const MessageMeta message_meta, const std::string &message)>;
  using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, const MessageMeta meta,
                                            const std::string &message)>;
  void set_handler(const RequestHandler &handler);
  void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
@@ -55,6 +55,7 @@ class ServerNode : public AbstractNode {
  void CreateTcpServer();
  void Initialize();
  void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
  void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);

  std::shared_ptr<TcpServer> server_;
  std::unique_ptr<std::thread> server_thread_;
@@ -51,7 +51,20 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
  });
}

TcpClient::~TcpClient() { Stop(); }
TcpClient::~TcpClient() {
  if (buffer_event_) {
    bufferevent_free(buffer_event_);
    buffer_event_ = nullptr;
  }
  if (event_timeout_) {
    event_free(event_timeout_);
    event_timeout_ = nullptr;
  }
  if (event_base_) {
    event_base_free(event_base_);
    event_base_ = nullptr;
  }
}
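The explicit free order in this destructor matters with libevent: the bufferevent and any pending events must be freed before the event_base they were created on, since event_base_free does not free events still attached to the base. The same ordering shows up in the TcpServer destructor further down.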
std::string TcpClient::GetServerAddress() const { return server_address_; }
@@ -69,9 +82,9 @@ bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
void TcpClient::Init() {
  std::lock_guard<std::mutex> lock(connection_mutex_);
  if (buffer_event_) {
    return;
    bufferevent_free(buffer_event_);
    buffer_event_ = nullptr;
  }
  is_stop_ = false;
  if (!CommUtil::CheckIp(server_address_)) {
    MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
  }
@@ -82,8 +95,9 @@ void TcpClient::Init() {
  }
  if (event_base_ == nullptr) {
    event_base_ = event_base_new();
    MS_EXCEPTION_IF_NULL(event_base_);
    is_stop_ = false;
  }
  MS_EXCEPTION_IF_NULL(event_base_);

  sockaddr_in sin{};
  if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
@@ -127,26 +141,18 @@ void TcpClient::StartWithDelay(int seconds) {
void TcpClient::Stop() {
  std::lock_guard<std::mutex> lock(connection_mutex_);
  MS_LOG(INFO) << "Stop tcp client event buffer!";
  if (!is_stop_.load()) {
    if (buffer_event_) {
      bufferevent_free(buffer_event_);
      buffer_event_ = nullptr;
    }
    if (event_timeout_) {
      event_free(event_timeout_);
      event_timeout_ = nullptr;
    }
  MS_LOG(INFO) << "Stop tcp client!";
  if (event_base_got_break(event_base_)) {
    MS_LOG(DEBUG) << "The event base has stopped!";
    is_stop_ = true;
    return;
  }
}

void TcpClient::StopEventBase() {
  MS_LOG(INFO) << "Stop tcp client event base!";
  int ret = event_base_loopbreak(event_base_);
  if (ret != 0) {
    MS_LOG(ERROR) << "Event base loop break failed!";
  if (!is_stop_.load()) {
    is_stop_ = true;
    int ret = event_base_loopbreak(event_base_);
    if (ret != 0) {
      MS_LOG(ERROR) << "Event base loop break failed!";
    }
  }
}
@@ -280,6 +286,7 @@ void TcpClient::StartTimer(const uint32_t &time) {
void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }

const event_base &TcpClient::eventbase() { return *event_base_; }
}  // namespace core
}  // namespace ps
}  // namespace mindspore
@@ -58,7 +58,6 @@ class TcpClient {
  void Init();
  void StartWithDelay(int seconds);
  void Stop();
  static void StopEventBase();
  void Start();
  void StartWithNoBlock();
  void SetMessageCallback(const OnMessage &cb);
@@ -97,6 +96,7 @@ class TcpClient {
  std::atomic<bool> is_stop_;
  std::atomic<bool> is_connected_;
};
}  // namespace core
}  // namespace ps
}  // namespace mindspore
@@ -32,6 +32,7 @@
namespace mindspore {
namespace ps {
namespace core {
void TcpConnection::InitConnection() {
  tcp_message_handler_.SetCallback([&](const CommMessage &message) {
    OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
@@ -76,7 +77,22 @@ TcpServer::TcpServer(const std::string &address, std::uint16_t port)
      server_port_(port),
      is_stop_(true) {}

TcpServer::~TcpServer() { Stop(); }
TcpServer::~TcpServer() {
  if (signal_event_ != nullptr) {
    event_free(signal_event_);
    signal_event_ = nullptr;
  }
  if (listener_ != nullptr) {
    evconnlistener_free(listener_);
    listener_ = nullptr;
  }
  if (base_ != nullptr) {
    event_base_free(base_);
    base_ = nullptr;
  }
}
void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
                                  const OnAccepted &client_accept) {
@@ -136,7 +152,6 @@ void TcpServer::Init() {
}

void TcpServer::Start() {
  std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  MS_LOG(INFO) << "Start tcp server!";
  MS_EXCEPTION_IF_NULL(base_);
  int ret = event_base_dispatch(base_);
@@ -148,7 +163,7 @@ void TcpServer::Start() {
}

void TcpServer::StartWithNoBlock() {
  std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  std::lock_guard<std::mutex> lock(connection_mutex_);
  MS_LOG(INFO) << "Start tcp server with no block!";
  MS_EXCEPTION_IF_NULL(base_);
  int ret = event_base_loop(base_, EVLOOP_NONBLOCK);
@@ -187,33 +202,25 @@ void TcpServer::StartTimer(const uint32_t &time) {
}

void TcpServer::Stop() {
  std::lock_guard<std::mutex> lock(connection_mutex_);
  MS_LOG(INFO) << "Stop tcp server!";
  if (event_base_got_break(base_)) {
    MS_LOG(DEBUG) << "The event base has stopped!";
    is_stop_ = true;
    return;
  }
  if (!is_stop_.load()) {
    is_stop_ = true;
    int ret = event_base_loopbreak(base_);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "event base loop break failed!";
    }
    if (signal_event_ != nullptr) {
      event_free(signal_event_);
      signal_event_ = nullptr;
    }
    if (listener_ != nullptr) {
      evconnlistener_free(listener_);
      listener_ = nullptr;
      MS_LOG(ERROR) << "Event base loop break failed!";
    }
    if (base_ != nullptr) {
      event_base_free(base_);
      base_ = nullptr;
    }
    is_stop_ = true;
  }
}
void TcpServer::SendToAllClients(const char *data, size_t len) {
  MS_EXCEPTION_IF_NULL(data);
  std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  std::lock_guard<std::mutex> lock(connection_mutex_);
  for (auto it = connections_.begin(); it != connections_.end(); ++it) {
    it->second->SendMessage(data, len);
  }
@@ -221,12 +228,12 @@ void TcpServer::SendToAllClients(const char *data, size_t len) {
void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) {
  MS_EXCEPTION_IF_NULL(connection);
  std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  std::lock_guard<std::mutex> lock(connection_mutex_);
  connections_.insert(std::make_pair(fd, connection));
}

void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
  std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  std::lock_guard<std::mutex> lock(connection_mutex_);
  TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
  delete connection;
  connections_.erase(fd);
@@ -352,7 +359,7 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }

void TcpServer::SendMessage(const CommMessage &message) {
  std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  std::lock_guard<std::mutex> lock(connection_mutex_);
  for (auto it = connections_.begin(); it != connections_.end(); ++it) {
    SendMessage(*it->second, message);
@@ -368,6 +375,7 @@ int TcpServer::ConnectionNum() const { return connections_.size(); }
const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; }

void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
}  // namespace core
}  // namespace ps
}  // namespace mindspore
@@ -121,7 +121,7 @@ class TcpServer {
  OnConnected client_connection_;
  OnDisconnected client_disconnection_;
  OnAccepted client_accept_;
  std::recursive_mutex connection_mutex_;
  std::mutex connection_mutex_;
  OnServerReceiveMessage message_callback_;
  OnTimerOnce on_timer_once_callback_;
  OnTimer on_timer_callback_;
@@ -21,24 +21,7 @@ namespace ps {
namespace core {
WorkerNode::~WorkerNode() {
  MS_LOG(INFO) << "Stop worker node!";
  if (!is_already_stopped_.load()) {
    is_ready_ = true;
    is_timeout_ = true;
    client_to_scheduler_->Stop();
    if (!connected_nodes_.empty()) {
      for (auto &connected_node : connected_nodes_) {
        connected_node.second->Stop();
      }
    }
    client_to_scheduler_->StopEventBase();
    if (client_to_scheduler_thread_->joinable()) {
      client_to_scheduler_thread_->join();
    }
    if (heart_beat_thread_->joinable()) {
      heart_beat_thread_->join();
    }
    is_already_stopped_ = true;
  }
  Stop();
}
bool WorkerNode::Start(const uint32_t &timeout) {
  MS_LOG(INFO) << "Starting worker node!";
@@ -78,19 +61,15 @@ bool WorkerNode::Stop() {
  if (!is_already_stopped_.load()) {
    is_ready_ = true;
    is_timeout_ = true;
    is_finish_ = true;
    heart_beat_thread_->join();
    client_to_scheduler_->Stop();
    if (!connected_nodes_.empty()) {
      for (auto &connected_node : connected_nodes_) {
        connected_node.second->Stop();
      }
    }
    client_to_scheduler_->StopEventBase();
    if (client_to_scheduler_thread_->joinable()) {
      client_to_scheduler_thread_->join();
    }
    if (heart_beat_thread_->joinable()) {
      heart_beat_thread_->join();
    }
    client_to_scheduler_thread_->join();
    is_already_stopped_ = true;
  }
  return true;
@@ -0,0 +1,42 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/common_test.h"
#define protected public
#include "ps/core/worker_node.h"
#undef protected

namespace mindspore {
namespace ps {
namespace core {
class TestAbstractNode : public UT::Common {
 public:
  TestAbstractNode() = default;
  virtual ~TestAbstractNode() = default;

  void SetUp() override {}
  void TearDown() override {}
};

TEST_F(TestAbstractNode, NextExpectedRankRequestId) {
  WorkerNode workerNode;
  ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(0));
  ASSERT_EQ(2, workerNode.NextExpectedRankRequestId(0));
  ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(1));
}
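A companion case for the actual-side counter would look much the same; this is a sketch, relying on the same #define protected public trick above to reach the protected NextActualRankRequestId:

TEST_F(TestAbstractNode, NextActualRankRequestId) {
  WorkerNode workerNode;
  // Each rank tracks its own monotonically increasing counter, starting at 1.
  ASSERT_EQ(1, workerNode.NextActualRankRequestId(0));
  ASSERT_EQ(2, workerNode.NextActualRankRequestId(0));
  ASSERT_EQ(1, workerNode.NextActualRankRequestId(1));
}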
}  // namespace core
}  // namespace ps
}  // namespace mindspore