You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

server.h 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
  17. #define MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
  21. #include "ps/core/communicator/communicator_base.h"
  22. #include "ps/core/communicator/tcp_communicator.h"
  23. #include "ps/core/communicator/task_executor.h"
  24. #include "ps/core/file_configuration.h"
  25. #include "fl/server/common.h"
  26. #include "fl/server/executor.h"
  27. #include "fl/server/iteration.h"
  28. #ifdef ENABLE_ARMOUR
  29. #include "fl/armour/cipher/cipher_init.h"
  30. #endif
  31. namespace mindspore {
  32. namespace fl {
  33. namespace server {
  34. // The sleeping time of the server thread before the networking is completed.
  35. constexpr uint32_t kServerSleepTimeForNetworking = 1000;
  36. // Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
  37. class Server {
  38. public:
  39. static Server &GetInstance() {
  40. static Server instance;
  41. return instance;
  42. }
  43. void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
  44. const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold);
  45. // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
  46. // blocked until the server is finalized.
  47. // func_graph is the frontend graph which will be parse in server's exector and aggregator.
  48. void Run();
  49. void SwitchToSafeMode();
  50. void CancelSafeMode();
  51. bool IsSafeMode() const;
  52. void WaitExitSafeMode() const;
  53. // Whether the training job of the server is enabled.
  54. InstanceState instance_state() const;
  55. private:
  56. Server()
  57. : server_node_(nullptr),
  58. task_executor_(nullptr),
  59. use_tcp_(false),
  60. use_http_(false),
  61. http_port_(0),
  62. func_graph_(nullptr),
  63. executor_threshold_(0),
  64. communicator_with_server_(nullptr),
  65. communicators_with_worker_({}),
  66. iteration_(nullptr),
  67. safemode_(true),
  68. scheduler_ip_(""),
  69. scheduler_port_(0),
  70. server_num_(0),
  71. worker_num_(0),
  72. fl_server_port_(0),
  73. cipher_initial_client_cnt_(0),
  74. cipher_exchange_secrets_cnt_(0),
  75. cipher_share_secrets_cnt_(0),
  76. cipher_get_clientlist_cnt_(0),
  77. cipher_reconstruct_secrets_up_cnt_(0),
  78. cipher_reconstruct_secrets_down_cnt_(0),
  79. cipher_time_window_(0) {}
  80. ~Server() = default;
  81. Server(const Server &) = delete;
  82. Server &operator=(const Server &) = delete;
  83. // Load variables which is set by ps_context.
  84. void InitServerContext();
  85. // Try to recover server config from persistent storage.
  86. void Recovery();
  87. // Initialize the server cluster, server node and communicators.
  88. void InitCluster();
  89. bool InitCommunicatorWithServer();
  90. bool InitCommunicatorWithWorker();
  91. // Initialize iteration with rounds. Which rounds to use could be set by ps_context as well.
  92. void InitIteration();
  93. // Register all message and event callbacks for communicators(TCP and HTTP). This method must be called before
  94. // communicators are started.
  95. void RegisterCommCallbacks();
  96. // Register cluster exception callbacks. This method is called in RegisterCommCallbacks.
  97. void RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
  98. // Register message callbacks. These messages are mainly from scheduler.
  99. void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
  100. // Initialize executor according to the server mode.
  101. void InitExecutor();
  102. // Initialize cipher according to the public param.
  103. void InitCipher();
  104. // Create round kernels and bind these kernels with corresponding Round.
  105. void RegisterRoundKernel();
  106. void InitMetrics();
  107. // The communicators should be started after all initializations are completed.
  108. void StartCommunicator();
  109. // The barriers before scaling operations.
  110. void ProcessBeforeScalingOut();
  111. void ProcessBeforeScalingIn();
  112. // The handlers after scheduler's scaling operations are done.
  113. void ProcessAfterScalingOut();
  114. void ProcessAfterScalingIn();
  115. // Handlers for enableFLS/disableFLS requests from the scheduler.
  116. void HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  117. void HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  118. // Finish current instance and start a new one. FLPlan could be changed in this method.
  119. void HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  120. // Query current instance information.
  121. void HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  122. // The server node is initialized in Server.
  123. std::shared_ptr<ps::core::ServerNode> server_node_;
  124. // The task executor of the communicators. This helps server to handle network message concurrently. The tasks
  125. // submitted to this task executor is asynchronous.
  126. std::shared_ptr<ps::core::TaskExecutor> task_executor_;
  127. // Which protocol should communicators use.
  128. bool use_tcp_;
  129. bool use_http_;
  130. uint16_t http_port_;
  131. // The configure of all rounds.
  132. std::vector<RoundConfig> rounds_config_;
  133. CipherConfig cipher_config_;
  134. // The graph passed by the frontend without backend optimizing.
  135. FuncGraphPtr func_graph_;
  136. // The threshold count for executor to do aggregation or optimizing.
  137. size_t executor_threshold_;
  138. // Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective
  139. // operations, etc.
  140. std::shared_ptr<ps::core::CommunicatorBase> communicator_with_server_;
  141. // The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP.
  142. // In some cases, both types should be supported in one distributed training job. So here we may have multiple
  143. // communicators.
  144. std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
  145. // Mutex for scaling operations. We must wait server's initialization done before handle scaling events.
  146. std::mutex scaling_mtx_;
  147. // Iteration consists of multiple kinds of rounds.
  148. Iteration *iteration_;
  149. // The flag that represents whether server is in safemode.
  150. // If true, the server is not available to workers and clients.
  151. std::atomic_bool safemode_;
  152. // Variables set by ps context.
  153. #ifdef ENABLE_ARMOUR
  154. armour::CipherInit *cipher_init_;
  155. #endif
  156. std::string scheduler_ip_;
  157. uint16_t scheduler_port_;
  158. uint32_t server_num_;
  159. uint32_t worker_num_;
  160. uint16_t fl_server_port_;
  161. size_t cipher_initial_client_cnt_;
  162. size_t cipher_exchange_secrets_cnt_;
  163. size_t cipher_share_secrets_cnt_;
  164. size_t cipher_get_clientlist_cnt_;
  165. size_t cipher_reconstruct_secrets_up_cnt_;
  166. size_t cipher_reconstruct_secrets_down_cnt_;
  167. uint64_t cipher_time_window_;
  168. };
  169. } // namespace server
  170. } // namespace fl
  171. } // namespace mindspore
  172. #endif // MINDSPORE_CCSRC_FL_SERVER_SERVER_H_