You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

distributed_count_service.cc 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/server/distributed_count_service.h"
  17. #include <string>
  18. #include <memory>
  19. #include <vector>
  20. namespace mindspore {
  21. namespace ps {
  22. namespace server {
  23. void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node,
  24. uint32_t counting_server_rank) {
  25. server_node_ = server_node;
  26. MS_EXCEPTION_IF_NULL(server_node_);
  27. communicator_ =
  28. std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr));
  29. MS_EXCEPTION_IF_NULL(communicator_);
  30. local_rank_ = server_node_->rank_id();
  31. server_num_ = PSContext::instance()->initial_server_num();
  32. counting_server_rank_ = counting_server_rank;
  33. RegisterCallback();
  34. return;
  35. }
  36. void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count,
  37. const CounterHandlers &counter_handlers) {
  38. if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) {
  39. MS_LOG(EXCEPTION) << "First count handler or last count handler is not set.";
  40. return;
  41. }
  42. if (global_threshold_count_.count(name) != 0) {
  43. MS_LOG(ERROR) << "Counter for " << name << " is already set.";
  44. return;
  45. }
  46. MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count;
  47. // If the server is the leader server, it needs to set the counter handlers and do the real counting.
  48. if (local_rank_ == counting_server_rank_) {
  49. global_current_count_[name] = {};
  50. global_threshold_count_[name] = global_threshold_count;
  51. mutex_[name];
  52. }
  53. counter_handlers_[name] = counter_handlers;
  54. return;
  55. }
  56. bool DistributedCountService::Count(const std::string &name, const std::string &id) {
  57. MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
  58. if (local_rank_ == counting_server_rank_) {
  59. if (global_threshold_count_.count(name) == 0) {
  60. MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
  61. return false;
  62. }
  63. std::unique_lock<std::mutex> lock(mutex_[name]);
  64. if (global_current_count_[name].size() >= global_threshold_count_[name]) {
  65. MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is "
  66. << global_threshold_count_[name];
  67. return false;
  68. }
  69. MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
  70. global_current_count_[name].insert(id);
  71. TriggerCounterEvent(name);
  72. } else {
  73. // If this server is a follower server, it needs to send CountRequest to the leader server.
  74. CountRequest report_count_req;
  75. report_count_req.set_name(name);
  76. report_count_req.set_id(id);
  77. std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
  78. if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount,
  79. &report_cnt_rsp_msg)) {
  80. MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
  81. return false;
  82. }
  83. CountResponse count_rsp;
  84. count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), report_cnt_rsp_msg->size());
  85. if (!count_rsp.result()) {
  86. MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
  87. return false;
  88. }
  89. }
  90. return true;
  91. }
  92. bool DistributedCountService::CountReachThreshold(const std::string &name) {
  93. MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
  94. if (local_rank_ == counting_server_rank_) {
  95. if (global_threshold_count_.count(name) == 0) {
  96. MS_LOG(ERROR) << "Counter for " << name << " is not set.";
  97. return false;
  98. }
  99. std::unique_lock<std::mutex> lock(mutex_[name]);
  100. return global_current_count_[name].size() == global_threshold_count_[name];
  101. } else {
  102. CountReachThresholdRequest count_reach_threshold_req;
  103. count_reach_threshold_req.set_name(name);
  104. std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
  105. if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
  106. core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
  107. MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
  108. return false;
  109. }
  110. CountReachThresholdResponse count_reach_threshold_rsp;
  111. count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
  112. return count_reach_threshold_rsp.is_enough();
  113. }
  114. }
  115. void DistributedCountService::ResetCounter(const std::string &name) {
  116. if (local_rank_ == counting_server_rank_) {
  117. MS_LOG(INFO) << "Leader server reset count for " << name;
  118. global_current_count_[name].clear();
  119. }
  120. return;
  121. }
  122. void DistributedCountService::RegisterCallback() {
  123. if (local_rank_ == counting_server_rank_) {
  124. communicator_->RegisterMsgCallBack(
  125. "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
  126. communicator_->RegisterMsgCallBack(
  127. "countReachThreshold",
  128. std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1));
  129. }
  130. // The callback of first/last event must be set in both leader server and follower servers.
  131. communicator_->RegisterMsgCallBack(
  132. "counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1));
  133. }
  134. void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) {
  135. if (message == nullptr) {
  136. MS_LOG(ERROR) << "Message is nullptr.";
  137. return;
  138. }
  139. CountRequest report_count_req;
  140. report_count_req.ParseFromArray(message->data(), message->len());
  141. const std::string &name = report_count_req.name();
  142. const std::string &id = report_count_req.id();
  143. CountResponse count_rsp;
  144. std::unique_lock<std::mutex> lock(mutex_[name]);
  145. // If leader server has no counter for the name registered, return an error.
  146. if (global_threshold_count_.count(name) == 0) {
  147. std::string reason = "Counter for " + name + " is not registered.";
  148. count_rsp.set_result(false);
  149. count_rsp.set_reason(reason);
  150. MS_LOG(ERROR) << reason;
  151. communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
  152. return;
  153. }
  154. // If leader server already has enough count for the name, return an error.
  155. if (global_current_count_[name].size() >= global_threshold_count_[name]) {
  156. std::string reason =
  157. "Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]);
  158. count_rsp.set_result(false);
  159. count_rsp.set_reason(reason);
  160. MS_LOG(ERROR) << reason;
  161. communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
  162. return;
  163. }
  164. // Insert the id for the counter, which means the count for the name is increased.
  165. MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
  166. global_current_count_[name].insert(id);
  167. TriggerCounterEvent(name);
  168. count_rsp.set_result(true);
  169. count_rsp.set_reason("success");
  170. communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
  171. return;
  172. }
  173. void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) {
  174. if (message == nullptr) {
  175. MS_LOG(ERROR) << "Message is nullptr.";
  176. return;
  177. }
  178. CountReachThresholdRequest count_reach_threshold_req;
  179. count_reach_threshold_req.ParseFromArray(message->data(), message->len());
  180. const std::string &name = count_reach_threshold_req.name();
  181. std::unique_lock<std::mutex> lock(mutex_[name]);
  182. if (global_threshold_count_.count(name) == 0) {
  183. MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
  184. return;
  185. }
  186. CountReachThresholdResponse count_reach_threshold_rsp;
  187. count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
  188. communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(),
  189. count_reach_threshold_rsp.SerializeAsString().size(), message);
  190. return;
  191. }
  192. void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) {
  193. if (message == nullptr) {
  194. MS_LOG(ERROR) << "Message is nullptr.";
  195. return;
  196. }
  197. // Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the
  198. // callbacks.
  199. std::string couter_event_rsp_msg = "success";
  200. communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message);
  201. CounterEvent counter_event;
  202. counter_event.ParseFromArray(message->data(), message->len());
  203. const auto &type = counter_event.type();
  204. const auto &name = counter_event.name();
  205. MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
  206. if (type == CounterEventType::FIRST_CNT) {
  207. counter_handlers_[name].first_count_handler(message);
  208. } else if (type == CounterEventType::LAST_CNT) {
  209. counter_handlers_[name].last_count_handler(message);
  210. } else {
  211. MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid.";
  212. return;
  213. }
  214. return;
  215. }
  216. void DistributedCountService::TriggerCounterEvent(const std::string &name) {
  217. MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
  218. << ", threshold count is " << global_threshold_count_[name];
  219. // The threshold count may be 1 so the first and last count event should be both activated.
  220. if (global_current_count_[name].size() == 1) {
  221. TriggerFirstCountEvent(name);
  222. }
  223. if (global_current_count_[name].size() == global_threshold_count_[name]) {
  224. TriggerLastCountEvent(name);
  225. }
  226. return;
  227. }
  228. void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
  229. MS_LOG(INFO) << "Activating first count event for " << name;
  230. CounterEvent first_count_event;
  231. first_count_event.set_type(CounterEventType::FIRST_CNT);
  232. first_count_event.set_name(name);
  233. // Broadcast to all follower servers.
  234. for (uint32_t i = 1; i < server_num_; i++) {
  235. if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
  236. MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
  237. return;
  238. }
  239. }
  240. // Leader server directly calls the callback.
  241. counter_handlers_[name].first_count_handler(nullptr);
  242. return;
  243. }
  244. void DistributedCountService::TriggerLastCountEvent(const std::string &name) {
  245. MS_LOG(INFO) << "Activating last count event for " << name;
  246. CounterEvent last_count_event;
  247. last_count_event.set_type(CounterEventType::LAST_CNT);
  248. last_count_event.set_name(name);
  249. // Broadcast to all follower servers.
  250. for (uint32_t i = 1; i < server_num_; i++) {
  251. if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
  252. MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
  253. return;
  254. }
  255. }
  256. // Leader server directly calls the callback.
  257. counter_handlers_[name].last_count_handler(nullptr);
  258. return;
  259. }
  260. } // namespace server
  261. } // namespace ps
  262. } // namespace mindspore