/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/server/distributed_count_service.h"

#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

namespace mindspore {
namespace ps {
namespace server {
// Binds the service to the server node's TCP communicator and records which rank does the real counting
// (the "leader"); all other ranks forward their count requests to it.
// @param server_node The local server node. Must not be null.
// @param counting_server_rank Rank of the leader server that owns every counter.
void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node,
                                         uint32_t counting_server_rank) {
  server_node_ = server_node;
  MS_EXCEPTION_IF_NULL(server_node_);

  communicator_ =
    std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr));
  MS_EXCEPTION_IF_NULL(communicator_);

  local_rank_ = server_node_->rank_id();
  server_num_ = PSContext::instance()->initial_server_num();
  counting_server_rank_ = counting_server_rank;
  RegisterCallback();
}

// Registers a named counter and its first/last count handlers.
// Every rank stores the handlers; only the leader also stores the threshold, the set of counted ids and the mutex.
// Throws (via MS_LOG(EXCEPTION)) if either handler is unset; logs and ignores a duplicate registration.
void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count,
                                              const CounterHandlers &counter_handlers) {
  if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) {
    MS_LOG(EXCEPTION) << "First count handler or last count handler is not set.";
  }
  // counter_handlers_ is filled on every rank (global_threshold_count_ only on the leader), so this also catches
  // duplicate registration on follower servers.
  if (counter_handlers_.count(name) != 0) {
    MS_LOG(ERROR) << "Counter for " << name << " is already set.";
    return;
  }

  MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count;
  // If the server is the leader server, it needs to set the counter handlers and do the real counting.
  if (local_rank_ == counting_server_rank_) {
    global_current_count_[name] = {};
    global_threshold_count_[name] = global_threshold_count;
    // Default-construct the per-counter mutex so later lookups find it.
    mutex_[name];
  }
  counter_handlers_[name] = counter_handlers;
}

// Reports one count of `id` for counter `name`.
// On the leader the id is recorded locally; on a follower a CountRequest is sent to the leader.
// Counting an id that is already recorded is accepted but does not re-trigger counter events.
// @return false if the counter is unknown, already full, or the leader could not be reached / rejected the count.
bool DistributedCountService::Count(const std::string &name, const std::string &id) {
  MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
  if (local_rank_ == counting_server_rank_) {
    if (global_threshold_count_.count(name) == 0) {
      MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
      return false;
    }

    std::unique_lock<std::mutex> lock(mutex_[name]);
    if (global_current_count_[name].size() >= global_threshold_count_[name]) {
      MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is "
                    << global_threshold_count_[name];
      return false;
    }

    // Only a newly inserted id changes the count; re-triggering on a duplicate would fire FIRST_CNT twice.
    if (!global_current_count_[name].insert(id).second) {
      MS_LOG(WARNING) << "Id " << id << " has already been counted for " << name << ".";
      return true;
    }
    MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
    TriggerCounterEvent(name);
    return true;
  }

  // If this server is a follower server, it needs to send CountRequest to the leader server.
  CountRequest report_count_req;
  report_count_req.set_name(name);
  report_count_req.set_id(id);
  std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
  if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount,
                                    &report_cnt_rsp_msg)) {
    MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
    return false;
  }

  CountResponse count_rsp;
  if (report_cnt_rsp_msg == nullptr ||
      !count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), static_cast<int>(report_cnt_rsp_msg->size()))) {
    MS_LOG(ERROR) << "Parsing count response from leader server failed for " << name;
    return false;
  }
  if (!count_rsp.result()) {
    MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
    return false;
  }
  return true;
}

// Returns whether the number of distinct ids counted for `name` equals its threshold.
// The leader answers locally; a follower queries the leader. Any error yields false.
bool DistributedCountService::CountReachThreshold(const std::string &name) {
  MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
  if (local_rank_ == counting_server_rank_) {
    if (global_threshold_count_.count(name) == 0) {
      MS_LOG(ERROR) << "Counter for " << name << " is not set.";
      return false;
    }

    std::unique_lock<std::mutex> lock(mutex_[name]);
    return global_current_count_[name].size() == global_threshold_count_[name];
  }

  CountReachThresholdRequest count_reach_threshold_req;
  count_reach_threshold_req.set_name(name);
  std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
  if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
                                    core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
    MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
    return false;
  }

  CountReachThresholdResponse count_reach_threshold_rsp;
  if (query_cnt_enough_rsp_msg == nullptr ||
      !count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
                                                static_cast<int>(query_cnt_enough_rsp_msg->size()))) {
    MS_LOG(ERROR) << "Parsing count reach threshold response from leader server failed for " << name;
    return false;
  }
  return count_reach_threshold_rsp.is_enough();
}

// Clears the ids counted for `name`. Only the leader holds counts, so this is a no-op on followers.
// NOTE(review): the clear is done without mutex_[name], so it can race with HandleCountRequest on the communicator
// thread. A lock is not taken here because counter handlers are invoked from TriggerCounterEvent while that mutex is
// held and might call back into this method (std::mutex is not recursive) -- confirm call sites before adding one.
void DistributedCountService::ResetCounter(const std::string &name) {
  if (local_rank_ != counting_server_rank_) {
    return;
  }
  // find() instead of operator[] so an unknown name does not create a stray entry.
  auto count_iter = global_current_count_.find(name);
  if (count_iter == global_current_count_.end()) {
    MS_LOG(WARNING) << "Counter for " << name << " is not registered, nothing to reset.";
    return;
  }
  MS_LOG(INFO) << "Leader server reset count for " << name;
  count_iter->second.clear();
}

// Registers message callbacks: the leader serves count/threshold queries; every rank handles counter events.
void DistributedCountService::RegisterCallback() {
  if (local_rank_ == counting_server_rank_) {
    communicator_->RegisterMsgCallBack(
      "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
    communicator_->RegisterMsgCallBack(
      "countReachThreshold",
      std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1));
  }

  // The callback of first/last event must be set in both leader server and follower servers.
  communicator_->RegisterMsgCallBack(
    "counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1));
}

// Leader-side handler for a follower's CountRequest. Always replies with a CountResponse
// (result=false plus a reason on any error) unless the message itself is null.
void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) {
  if (message == nullptr) {
    MS_LOG(ERROR) << "Message is nullptr.";
    return;
  }

  CountResponse count_rsp;
  // Serialize once and send the current count_rsp back to the requesting follower.
  const auto send_response = [this, &count_rsp, &message]() {
    const std::string serialized_rsp = count_rsp.SerializeAsString();
    communicator_->SendResponse(serialized_rsp.data(), serialized_rsp.size(), message);
  };

  CountRequest report_count_req;
  if (!report_count_req.ParseFromArray(message->data(), static_cast<int>(message->len()))) {
    std::string reason = "Parsing count request failed.";
    count_rsp.set_result(false);
    count_rsp.set_reason(reason);
    MS_LOG(ERROR) << reason;
    send_response();
    return;
  }
  const std::string &name = report_count_req.name();
  const std::string &id = report_count_req.id();

  // If leader server has no counter for the name registered, return an error. Checked before locking so that
  // mutex_[name] does not create an entry for an unknown name (same order as Count).
  if (global_threshold_count_.count(name) == 0) {
    std::string reason = "Counter for " + name + " is not registered.";
    count_rsp.set_result(false);
    count_rsp.set_reason(reason);
    MS_LOG(ERROR) << reason;
    send_response();
    return;
  }

  std::unique_lock<std::mutex> lock(mutex_[name]);
  // If leader server already has enough count for the name, return an error.
  if (global_current_count_[name].size() >= global_threshold_count_[name]) {
    std::string reason = "Count for " + name + " is already enough. Threshold count is " +
                         std::to_string(global_threshold_count_[name]);
    count_rsp.set_result(false);
    count_rsp.set_reason(reason);
    MS_LOG(ERROR) << reason;
    send_response();
    return;
  }

  // Insert the id for the counter, which means the count for the name is increased. A duplicate id leaves the
  // count unchanged, so counter events are only triggered for a newly inserted id.
  if (global_current_count_[name].insert(id).second) {
    MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
    TriggerCounterEvent(name);
  } else {
    MS_LOG(WARNING) << "Id " << id << " has already been counted for " << name << ".";
  }
  count_rsp.set_result(true);
  count_rsp.set_reason("success");
  send_response();
}

// Leader-side handler for a follower's CountReachThresholdRequest.
// A reply is sent on every path (is_enough=false on error) because the follower's CountReachThreshold reads one.
void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) {
  if (message == nullptr) {
    MS_LOG(ERROR) << "Message is nullptr.";
    return;
  }

  CountReachThresholdResponse count_reach_threshold_rsp;
  count_reach_threshold_rsp.set_is_enough(false);

  CountReachThresholdRequest count_reach_threshold_req;
  if (!count_reach_threshold_req.ParseFromArray(message->data(), static_cast<int>(message->len()))) {
    MS_LOG(ERROR) << "Parsing count reach threshold request failed.";
  } else {
    const std::string &name = count_reach_threshold_req.name();
    if (global_threshold_count_.count(name) == 0) {
      MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
    } else {
      std::unique_lock<std::mutex> lock(mutex_[name]);
      count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
    }
  }

  const std::string serialized_rsp = count_reach_threshold_rsp.SerializeAsString();
  communicator_->SendResponse(serialized_rsp.data(), serialized_rsp.size(), message);
}

// Follower-side handler for a CounterEvent broadcast by the leader: acknowledges it, then runs the matching
// registered first/last count handler with the received message.
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) {
  if (message == nullptr) {
    MS_LOG(ERROR) << "Message is nullptr.";
    return;
  }

  // Respond as soon as possible so the leader server won't wait for each follower server to finish calling the
  // callbacks.
  const std::string counter_event_rsp_msg = "success";
  communicator_->SendResponse(counter_event_rsp_msg.data(), counter_event_rsp_msg.size(), message);

  CounterEvent counter_event;
  if (!counter_event.ParseFromArray(message->data(), static_cast<int>(message->len()))) {
    MS_LOG(ERROR) << "Parsing counter event failed.";
    return;
  }
  const auto &type = counter_event.type();
  const auto &name = counter_event.name();

  MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
  // find() instead of operator[]: an unknown name would otherwise insert default (empty) handlers, and calling an
  // empty std::function throws std::bad_function_call.
  auto handlers_iter = counter_handlers_.find(name);
  if (handlers_iter == counter_handlers_.end()) {
    MS_LOG(ERROR) << "Counter handlers for " << name << " are not registered.";
    return;
  }
  const CounterHandlers &handlers = handlers_iter->second;
  if (type == CounterEventType::FIRST_CNT) {
    handlers.first_count_handler(message);
  } else if (type == CounterEventType::LAST_CNT) {
    handlers.last_count_handler(message);
  } else {
    MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid.";
  }
}

// Fires the first-count event when exactly one id is counted and the last-count event when the count reaches the
// threshold. Both callers (Count on the leader and HandleCountRequest) hold mutex_[name] while calling this, so the
// leader's handlers run under that lock.
void DistributedCountService::TriggerCounterEvent(const std::string &name) {
  MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
               << ", threshold count is " << global_threshold_count_[name];
  // The threshold count may be 1 so the first and last count event should be both activated.
  if (global_current_count_[name].size() == 1) {
    TriggerFirstCountEvent(name);
  }
  if (global_current_count_[name].size() == global_threshold_count_[name]) {
    TriggerLastCountEvent(name);
  }
}

// Broadcasts a FIRST_CNT event to every follower, then runs the leader's own first count handler.
// Stops at the first follower that cannot be reached (the leader's handler is then not run).
void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
  MS_LOG(INFO) << "Activating first count event for " << name;
  CounterEvent first_count_event;
  first_count_event.set_type(CounterEventType::FIRST_CNT);
  first_count_event.set_name(name);

  // Broadcast to all follower servers. The leader is counting_server_rank_, which is not necessarily rank 0.
  for (uint32_t i = 0; i < server_num_; i++) {
    if (i == counting_server_rank_) {
      continue;
    }
    if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
      MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
      return;
    }
  }
  // Leader server directly calls the callback.
  counter_handlers_[name].first_count_handler(nullptr);
}

// Broadcasts a LAST_CNT event to every follower, then runs the leader's own last count handler.
// Stops at the first follower that cannot be reached (the leader's handler is then not run).
void DistributedCountService::TriggerLastCountEvent(const std::string &name) {
  MS_LOG(INFO) << "Activating last count event for " << name;
  CounterEvent last_count_event;
  last_count_event.set_type(CounterEventType::LAST_CNT);
  last_count_event.set_name(name);

  // Broadcast to all follower servers. The leader is counting_server_rank_, which is not necessarily rank 0.
  for (uint32_t i = 0; i < server_num_; i++) {
    if (i == counting_server_rank_) {
      continue;
    }
    if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
      MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
      return;
    }
  }
  // Leader server directly calls the callback.
  counter_handlers_[name].last_count_handler(nullptr);
}
}  // namespace server
}  // namespace ps
}  // namespace mindspore