You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

distributed_count_service.h 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
  17. #define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
  #include <cstdint>
  #include <memory>
  #include <mutex>
  #include <set>
  #include <string>
  #include <unordered_map>

  #include "proto/ps.pb.h"
  #include "ps/server/common.h"
  #include "ps/core/server_node.h"
  #include "ps/core/communicator/tcp_communicator.h"
  26. namespace mindspore {
  27. namespace ps {
  28. namespace server {
  29. // The callbacks for the first count and last count event.
  30. typedef struct {
  31. MessageCallback first_count_handler;
  32. MessageCallback last_count_handler;
  33. } CounterHandlers;
  34. // DistributedCountService is used for counting in the server cluster dimension. It's used for counting of rounds,
  35. // aggregation counting, etc.
  36. // The counting could be called by any server, but only one server has the information
  37. // of the cluster count and we mark this server as the counting server. Other servers must communicate with this
  38. // counting server to increase/query count number.
  39. // On the first count or last count event, DistributedCountService on the counting server triggers the event on other
  40. // servers by sending counter event commands. This is for the purpose of keeping server cluster's consistency.
  41. class DistributedCountService {
  42. public:
  43. static DistributedCountService &GetInstance() {
  44. static DistributedCountService instance;
  45. return instance;
  46. }
  47. // Initialize counter service with the server node because communication is needed.
  48. void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank);
  49. // Register counter to the counting server for the name with its threshold count in server cluster dimension and
  50. // first/last count event callbacks.
  51. void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers);
  52. // Report a count to the counting server. Parameter 'id' is in case of repeated counting.
  53. bool Count(const std::string &name, const std::string &id);
  54. // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count,
  55. // this method returns true.
  56. bool CountReachThreshold(const std::string &name);
  57. // Reset the count of the name to 0.
  58. void ResetCounter(const std::string &name);
  59. // Returns the server rank because in some cases the callers use this rank as the 'id' for method
  60. // Count.
  61. uint32_t local_rank() { return local_rank_; }
  62. private:
  63. DistributedCountService() = default;
  64. ~DistributedCountService() = default;
  65. DistributedCountService(const DistributedCountService &) = delete;
  66. DistributedCountService &operator=(const DistributedCountService &) = delete;
  67. // Register callbacks of the counting server to handle messages sent by the other servers.
  68. void RegisterCallback();
  69. // Callback for the reporting count message from other servers. Only counting server will call this method.
  70. void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message);
  71. // Callback for the querying whether threshold count is reached message from other servers. Only counting
  72. // server will call this method.
  73. void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message);
  74. // Callback for the first/last event message from the counting server. Only other servers will call this
  75. // method.
  76. void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message);
  77. // Call the callbacks when the first/last count event is triggered.
  78. void TriggerCounterEvent(const std::string &name);
  79. void TriggerFirstCountEvent(const std::string &name);
  80. void TriggerLastCountEvent(const std::string &name);
  81. // Members for the communication between counting server and other servers.
  82. std::shared_ptr<core::ServerNode> server_node_;
  83. std::shared_ptr<core::TcpCommunicator> communicator_;
  84. uint32_t local_rank_;
  85. uint32_t server_num_;
  86. // Only one server will be set to do the real counting.
  87. uint32_t counting_server_rank_;
  88. // Key: name, e.g, startFLJob, updateModel, push.
  89. // Value: a set of id without repeatation because each work may report multiple times.
  90. std::unordered_map<std::string, std::set<std::string>> global_current_count_;
  91. // Key: name, e.g, StartFLJobCount.
  92. // Value: global threshold count in the server cluster dimension for this name.
  93. std::unordered_map<std::string, size_t> global_threshold_count_;
  94. // First/last count event callbacks of the name.
  95. std::unordered_map<std::string, CounterHandlers> counter_handlers_;
  96. // Because the count is increased/queried conccurently, we must ensure the operations are threadsafe.
  97. std::unordered_map<std::string, std::mutex> mutex_;
  98. };
  99. } // namespace server
  100. } // namespace ps
  101. } // namespace mindspore
  102. #endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_