You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

server.h 9.2 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
  17. #define MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
  #include <atomic>
  #include <cstddef>
  #include <cstdint>
  #include <memory>
  #include <mutex>
  #include <string>
  #include <vector>
  #include "ps/core/communicator/communicator_base.h"
  #include "ps/core/communicator/tcp_communicator.h"
  #include "ps/core/communicator/task_executor.h"
  #include "ps/core/file_configuration.h"
  #ifdef ENABLE_ARMOUR
  #include "fl/armour/cipher/cipher_init.h"
  #endif
  #include "fl/server/common.h"
  #include "fl/server/executor.h"
  #include "fl/server/iteration.h"
  31. namespace mindspore {
  32. namespace fl {
  33. namespace server {
  // How long the server thread sleeps while it waits for the networking to be completed.
  // NOTE(review): the unit is not stated in this header - presumably milliseconds; confirm at the call site.
  constexpr uint32_t kServerSleepTimeForNetworking{1000};
  // Initial value given to Server::replay_attack_time_diff_ by the Server constructor.
  // NOTE(review): the unit is not stated in this header - presumably milliseconds; confirm where it is consumed.
  constexpr uint64_t kDefaultReplayAttackTimeDiff{60000};
  37. // Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
  38. class Server {
  39. public:
  40. static Server &GetInstance() {
  41. static Server instance;
  42. return instance;
  43. }
  44. void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
  45. const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold);
  46. // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
  47. // blocked until the server is finalized.
  48. // func_graph is the frontend graph which will be parse in server's exector and aggregator.
  49. // Each step of the server pipeline may have dependency on other steps, which includes:
  50. // InitServerContext must be the first step to set contexts for later steps.
  51. // Server Running relies on URL or Message Type Register:
  52. // StartCommunicator---->InitIteration
  53. // Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion:
  54. // RegisterRoundKernel---->StartCommunicator
  55. // Kernel Initialization relies on Executor Initialization:
  56. // RegisterRoundKernel---->InitExecutor
  57. // Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
  58. // InitCipher---->InitExecutor
  59. void Run();
  60. void SwitchToSafeMode();
  61. void CancelSafeMode();
  62. bool IsSafeMode() const;
  63. void WaitExitSafeMode() const;
  64. // Whether the training job of the server is enabled.
  65. InstanceState instance_state() const;
  66. private:
  67. Server()
  68. : server_node_(nullptr),
  69. task_executor_(nullptr),
  70. use_tcp_(false),
  71. use_http_(false),
  72. http_port_(0),
  73. func_graph_(nullptr),
  74. executor_threshold_(0),
  75. communicator_with_server_(nullptr),
  76. communicators_with_worker_({}),
  77. iteration_(nullptr),
  78. safemode_(true),
  79. server_recovery_(nullptr),
  80. scheduler_ip_(""),
  81. scheduler_port_(0),
  82. server_num_(0),
  83. worker_num_(0),
  84. fl_server_port_(0),
  85. pki_verify_(false),
  86. root_first_ca_path_(""),
  87. root_second_ca_path_(""),
  88. equip_crl_path_(""),
  89. replay_attack_time_diff_(kDefaultReplayAttackTimeDiff),
  90. cipher_initial_client_cnt_(0),
  91. cipher_exchange_keys_cnt_(0),
  92. cipher_get_keys_cnt_(0),
  93. cipher_share_secrets_cnt_(0),
  94. cipher_get_secrets_cnt_(0),
  95. cipher_get_clientlist_cnt_(0),
  96. cipher_push_list_sign_cnt_(0),
  97. cipher_get_list_sign_cnt_(0),
  98. cipher_reconstruct_secrets_up_cnt_(0),
  99. cipher_reconstruct_secrets_down_cnt_(0),
  100. cipher_time_window_(0) {}
  101. ~Server() = default;
  102. Server(const Server &) = delete;
  103. Server &operator=(const Server &) = delete;
  104. // Load variables which is set by ps_context.
  105. void InitServerContext();
  106. // Initialize the server cluster, server node and communicators.
  107. void InitCluster();
  108. bool InitCommunicatorWithServer();
  109. bool InitCommunicatorWithWorker();
  110. // Initialize iteration with rounds. Which rounds to use could be set by ps_context as well.
  111. void InitIteration();
  112. // Register all message and event callbacks for communicators(TCP and HTTP). This method must be called before
  113. // communicators are started.
  114. void RegisterCommCallbacks();
  115. // Register cluster exception callbacks. This method is called in RegisterCommCallbacks.
  116. void RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
  117. // Register message callbacks. These messages are mainly from scheduler.
  118. void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
  119. // Initialize executor according to the server mode.
  120. void InitExecutor();
  121. // Initialize cipher according to the public param.
  122. void InitCipher();
  123. // Create round kernels and bind these kernels with corresponding Round.
  124. void RegisterRoundKernel();
  125. void InitMetrics();
  126. // The communicators should be started after all initializations are completed.
  127. void StartCommunicator();
  128. // Try to recover server config from persistent storage.
  129. void Recover();
  130. // load pki huks cbg root certificate and crl
  131. void InitPkiCertificate();
  132. // The barriers before scaling operations.
  133. void ProcessBeforeScalingOut();
  134. void ProcessBeforeScalingIn();
  135. // The handlers after scheduler's scaling operations are done.
  136. void ProcessAfterScalingOut();
  137. void ProcessAfterScalingIn();
  138. // Handlers for enableFLS/disableFLS requests from the scheduler.
  139. void HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  140. void HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  141. // Finish current instance and start a new one. FLPlan could be changed in this method.
  142. void HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  143. // Query current instance information.
  144. void HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  145. // Synchronize after recovery is completed to ensure consistency.
  146. void HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  147. // The server node is initialized in Server.
  148. std::shared_ptr<ps::core::ServerNode> server_node_;
  149. // The task executor of the communicators. This helps server to handle network message concurrently. The tasks
  150. // submitted to this task executor is asynchronous.
  151. std::shared_ptr<ps::core::TaskExecutor> task_executor_;
  152. // Which protocol should communicators use.
  153. bool use_tcp_;
  154. bool use_http_;
  155. uint16_t http_port_;
  156. // The configure of all rounds.
  157. std::vector<RoundConfig> rounds_config_;
  158. CipherConfig cipher_config_;
  159. // The graph passed by the frontend without backend optimizing.
  160. FuncGraphPtr func_graph_;
  161. // The threshold count for executor to do aggregation or optimizing.
  162. size_t executor_threshold_;
  163. // Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective
  164. // operations, etc.
  165. std::shared_ptr<ps::core::CommunicatorBase> communicator_with_server_;
  166. // The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP.
  167. // In some cases, both types should be supported in one distributed training job. So here we may have multiple
  168. // communicators.
  169. std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
  170. // Mutex for scaling operations.
  171. std::mutex scaling_mtx_;
  172. // Iteration consists of multiple kinds of rounds.
  173. Iteration *iteration_;
  174. // The flag that represents whether server is in safemode.
  175. // If true, the server is not available to workers and clients.
  176. std::atomic_bool safemode_;
  177. // The recovery object for server.
  178. std::shared_ptr<ServerRecovery> server_recovery_;
  179. // Variables set by ps context.
  180. #ifdef ENABLE_ARMOUR
  181. armour::CipherInit *cipher_init_;
  182. #endif
  183. std::string scheduler_ip_;
  184. uint16_t scheduler_port_;
  185. uint32_t server_num_;
  186. uint32_t worker_num_;
  187. uint16_t fl_server_port_;
  188. bool pki_verify_;
  189. std::string root_first_ca_path_;
  190. std::string root_second_ca_path_;
  191. std::string equip_crl_path_;
  192. uint64_t replay_attack_time_diff_;
  193. size_t cipher_initial_client_cnt_;
  194. size_t cipher_exchange_keys_cnt_;
  195. size_t cipher_get_keys_cnt_;
  196. size_t cipher_share_secrets_cnt_;
  197. size_t cipher_get_secrets_cnt_;
  198. size_t cipher_get_clientlist_cnt_;
  199. size_t cipher_push_list_sign_cnt_;
  200. size_t cipher_get_list_sign_cnt_;
  201. size_t cipher_reconstruct_secrets_up_cnt_;
  202. size_t cipher_reconstruct_secrets_down_cnt_;
  203. uint64_t cipher_time_window_;
  204. };
  205. } // namespace server
  206. } // namespace fl
  207. } // namespace mindspore
  208. #endif // MINDSPORE_CCSRC_FL_SERVER_SERVER_H_