You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

server.cc 28 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "fl/server/server.h"
  17. #include <memory>
  18. #include <string>
  19. #include <csignal>
  20. #ifdef ENABLE_ARMOUR
  21. #include "fl/armour/secure_protocol/secret_sharing.h"
  22. #endif
  23. #include "fl/server/round.h"
  24. #include "fl/server/model_store.h"
  25. #include "fl/server/iteration.h"
  26. #include "fl/server/collective_ops_impl.h"
  27. #include "fl/server/distributed_metadata_store.h"
  28. #include "fl/server/distributed_count_service.h"
  29. #include "fl/server/kernel/round/round_kernel_factory.h"
  30. namespace mindspore {
  31. namespace fl {
  32. namespace server {
  33. // The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster manager like K8S.
  34. std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr;
  35. std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {};
  36. void SignalHandler(int signal) {
  37. MS_LOG(WARNING) << "SIGTERM captured: " << signal;
  38. (void)std::for_each(g_communicators_with_worker.begin(), g_communicators_with_worker.end(),
  39. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  40. MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
  41. (void)communicator->Stop();
  42. });
  43. MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server);
  44. (void)g_communicator_with_server->Stop();
  45. }
  46. void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
  47. const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
  48. MS_EXCEPTION_IF_NULL(func_graph);
  49. func_graph_ = func_graph;
  50. if (rounds_config.empty()) {
  51. MS_LOG(EXCEPTION) << "Rounds are empty.";
  52. return;
  53. }
  54. rounds_config_ = rounds_config;
  55. cipher_config_ = cipher_config;
  56. use_tcp_ = use_tcp;
  57. use_http_ = use_http;
  58. http_port_ = http_port;
  59. executor_threshold_ = executor_threshold;
  60. (void)signal(SIGTERM, SignalHandler);
  61. return;
  62. }
  63. void Server::Run() {
  64. std::unique_lock<std::mutex> lock(scaling_mtx_);
  65. InitServerContext();
  66. InitPkiCertificate();
  67. InitCluster();
  68. InitIteration();
  69. RegisterCommCallbacks();
  70. StartCommunicator();
  71. InitExecutor();
  72. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  73. if (encrypt_type != ps::kNotEncryptType) {
  74. InitCipher();
  75. MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
  76. }
  77. RegisterRoundKernel();
  78. InitMetrics();
  79. Recover();
  80. MS_LOG(INFO) << "Server started successfully.";
  81. safemode_ = false;
  82. lock.unlock();
  83. // Wait communicators to stop so the main thread is blocked.
  84. (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
  85. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  86. MS_EXCEPTION_IF_NULL(communicator);
  87. communicator->Join();
  88. });
  89. MS_EXCEPTION_IF_NULL(communicator_with_server_);
  90. communicator_with_server_->Join();
  91. MsException::Instance().CheckException();
  92. func_graph_ = nullptr;
  93. return;
  94. }
  95. void Server::InitPkiCertificate() {
  96. if (ps::PSContext::instance()->pki_verify()) {
  97. root_first_ca_path_ = ps::PSContext::instance()->root_first_ca_path();
  98. root_second_ca_path_ = ps::PSContext::instance()->root_second_ca_path();
  99. equip_crl_path_ = ps::PSContext::instance()->equip_crl_path();
  100. replay_attack_time_diff_ = ps::PSContext::instance()->replay_attack_time_diff();
  101. bool ret = mindspore::ps::server::CertVerify::initRootCertAndCRL(root_first_ca_path_, root_second_ca_path_,
  102. equip_crl_path_, replay_attack_time_diff_);
  103. if (!ret) {
  104. MS_LOG(EXCEPTION) << "init root cert and crl failed.";
  105. return;
  106. }
  107. return;
  108. }
  109. }
  110. void Server::SwitchToSafeMode() {
  111. MS_LOG(INFO) << "Server switch to safemode.";
  112. safemode_ = true;
  113. }
  114. void Server::CancelSafeMode() {
  115. MS_LOG(INFO) << "Server cancel safemode.";
  116. safemode_ = false;
  117. }
  118. bool Server::IsSafeMode() const { return safemode_.load(); }
  119. void Server::WaitExitSafeMode() const {
  120. while (safemode_.load()) {
  121. std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
  122. }
  123. }
  124. void Server::InitServerContext() {
  125. ps::PSContext::instance()->GenerateResetterRound();
  126. scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
  127. scheduler_port_ = ps::PSContext::instance()->scheduler_port();
  128. worker_num_ = ps::PSContext::instance()->initial_worker_num();
  129. server_num_ = ps::PSContext::instance()->initial_server_num();
  130. return;
  131. }
  132. void Server::InitCluster() {
  133. server_node_ = std::make_shared<ps::core::ServerNode>();
  134. MS_EXCEPTION_IF_NULL(server_node_);
  135. task_executor_ = std::make_shared<ps::core::TaskExecutor>(kExecutorThreadPoolSize);
  136. MS_EXCEPTION_IF_NULL(task_executor_);
  137. if (!InitCommunicatorWithServer()) {
  138. MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
  139. return;
  140. }
  141. if (!InitCommunicatorWithWorker()) {
  142. MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
  143. return;
  144. }
  145. return;
  146. }
  147. bool Server::InitCommunicatorWithServer() {
  148. MS_EXCEPTION_IF_NULL(task_executor_);
  149. MS_EXCEPTION_IF_NULL(server_node_);
  150. communicator_with_server_ = server_node_->GetOrCreateTcpComm(scheduler_ip_, static_cast<int16_t>(scheduler_port_),
  151. worker_num_, server_num_, task_executor_);
  152. MS_EXCEPTION_IF_NULL(communicator_with_server_);
  153. g_communicator_with_server = communicator_with_server_;
  154. return true;
  155. }
  156. bool Server::InitCommunicatorWithWorker() {
  157. MS_EXCEPTION_IF_NULL(server_node_);
  158. MS_EXCEPTION_IF_NULL(task_executor_);
  159. if (!use_tcp_ && !use_http_) {
  160. MS_LOG(EXCEPTION) << "At least one type of protocol should be set.";
  161. return false;
  162. }
  163. if (use_tcp_) {
  164. MS_EXCEPTION_IF_NULL(communicator_with_server_);
  165. auto tcp_comm = communicator_with_server_;
  166. MS_EXCEPTION_IF_NULL(tcp_comm);
  167. communicators_with_worker_.push_back(tcp_comm);
  168. }
  169. if (use_http_) {
  170. auto http_comm = server_node_->GetOrCreateHttpComm(server_node_->BoundIp(), http_port_, task_executor_);
  171. MS_EXCEPTION_IF_NULL(http_comm);
  172. communicators_with_worker_.push_back(http_comm);
  173. }
  174. g_communicators_with_worker = communicators_with_worker_;
  175. return true;
  176. }
  177. void Server::InitIteration() {
  178. iteration_ = &Iteration::GetInstance();
  179. MS_EXCEPTION_IF_NULL(iteration_);
  180. // 1.Add rounds to the iteration according to the server mode.
  181. for (const RoundConfig &config : rounds_config_) {
  182. std::shared_ptr<Round> round =
  183. std::make_shared<Round>(config.name, config.check_timeout, config.time_window, config.check_count,
  184. config.threshold_count, config.server_num_as_threshold);
  185. MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout
  186. << ", time window: " << config.time_window << ", check_count: " << config.check_count
  187. << ", threshold: " << config.threshold_count
  188. << ", server_num_as_threshold: " << config.server_num_as_threshold;
  189. iteration_->AddRound(round);
  190. }
  191. #ifdef ENABLE_ARMOUR
  192. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  193. if (encrypt_type == ps::kPWEncryptType) {
  194. cipher_exchange_keys_cnt_ = cipher_config_.exchange_keys_threshold;
  195. cipher_get_keys_cnt_ = cipher_config_.get_keys_threshold;
  196. cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
  197. cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
  198. cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
  199. cipher_push_list_sign_cnt_ = cipher_config_.push_list_sign_threshold;
  200. cipher_get_list_sign_cnt_ = cipher_config_.get_list_sign_threshold;
  201. cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
  202. cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
  203. cipher_time_window_ = cipher_config_.cipher_time_window;
  204. MS_LOG(INFO) << "Initializing cipher:";
  205. MS_LOG(INFO) << " cipher_exchange_keys_cnt_: " << cipher_exchange_keys_cnt_
  206. << " cipher_get_keys_cnt_: " << cipher_get_keys_cnt_
  207. << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
  208. MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
  209. << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
  210. << " cipher_push_list_sign_cnt_: " << cipher_push_list_sign_cnt_
  211. << " cipher_get_list_sign_cnt_: " << cipher_get_list_sign_cnt_
  212. << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
  213. << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
  214. << " cipher_time_window_: " << cipher_time_window_;
  215. }
  216. #endif
  217. // 2.Initialize all the rounds.
  218. TimeOutCb time_out_cb = std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
  219. FinishIterCb finish_iter_cb =
  220. std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
  221. iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
  222. return;
  223. }
  224. void Server::InitCipher() {
  225. #ifdef ENABLE_ARMOUR
  226. cipher_init_ = &armour::CipherInit::GetInstance();
  227. int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
  228. unsigned char cipher_p[SECRET_MAX_LEN] = {0};
  229. const int cipher_g = 1;
  230. float dp_eps = ps::PSContext::instance()->dp_eps();
  231. float dp_delta = ps::PSContext::instance()->dp_delta();
  232. float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
  233. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  234. mindspore::armour::CipherPublicPara param;
  235. param.g = cipher_g;
  236. param.t = cipher_t;
  237. int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p));
  238. if (ret != 0) {
  239. MS_LOG(EXCEPTION) << "Memcpy_s error, errorno" << ret;
  240. }
  241. param.dp_delta = dp_delta;
  242. param.dp_eps = dp_eps;
  243. param.dp_norm_clip = dp_norm_clip;
  244. param.encrypt_type = encrypt_type;
  245. BIGNUM *prim = BN_new();
  246. if (prim == NULL) {
  247. MS_LOG(EXCEPTION) << "new bn failed.";
  248. ret = -1;
  249. } else {
  250. ret = mindspore::armour::GetPrime(prim);
  251. }
  252. if (ret == 0) {
  253. (void)BN_bn2bin(prim, reinterpret_cast<uint8_t *>(param.prime));
  254. } else {
  255. MS_LOG(EXCEPTION) << "Get prime failed.";
  256. }
  257. if (prim != NULL) {
  258. BN_clear_free(prim);
  259. }
  260. cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
  261. cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_push_list_sign_cnt_,
  262. cipher_get_list_sign_cnt_, cipher_reconstruct_secrets_up_cnt_);
  263. #endif
  264. }
  265. void Server::RegisterCommCallbacks() {
  266. // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
  267. // rounds' callbacks.
  268. MS_EXCEPTION_IF_NULL(server_node_);
  269. MS_EXCEPTION_IF_NULL(iteration_);
  270. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  271. MS_EXCEPTION_IF_NULL(tcp_comm);
  272. // Set message callbacks for server-to-server communication.
  273. DistributedMetadataStore::GetInstance().RegisterMessageCallback(tcp_comm);
  274. DistributedCountService::GetInstance().RegisterMessageCallback(tcp_comm);
  275. iteration_->RegisterMessageCallback(tcp_comm);
  276. iteration_->RegisterEventCallback(server_node_);
  277. // Set exception event callbacks for server.
  278. RegisterExceptionEventCallback(tcp_comm);
  279. // Set message callbacks for server.
  280. RegisterMessageCallback(tcp_comm);
  281. if (!server_node_->InitFollowerScaler()) {
  282. MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
  283. return;
  284. }
  285. // Set scaling barriers before scaling.
  286. server_node_->RegisterFollowerScalerBarrierBeforeScaleOut("ServerPipeline",
  287. std::bind(&Server::ProcessBeforeScalingOut, this));
  288. server_node_->RegisterFollowerScalerBarrierBeforeScaleIn("ServerPipeline",
  289. std::bind(&Server::ProcessBeforeScalingIn, this));
  290. // Set handlers after scheduler scaling operations are done.
  291. server_node_->RegisterFollowerScalerHandlerAfterScaleOut("ServerPipeline",
  292. std::bind(&Server::ProcessAfterScalingOut, this));
  293. server_node_->RegisterFollowerScalerHandlerAfterScaleIn("ServerPipeline",
  294. std::bind(&Server::ProcessAfterScalingIn, this));
  295. }
  296. void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
  297. MS_EXCEPTION_IF_NULL(communicator);
  298. communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
  299. MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
  300. safemode_ = true;
  301. (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
  302. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  303. MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
  304. (void)communicator->Stop();
  305. });
  306. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  307. (void)communicator_with_server_->Stop();
  308. });
  309. communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
  310. MS_LOG(ERROR)
  311. << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
  312. "network building phase.";
  313. safemode_ = true;
  314. (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
  315. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  316. MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
  317. (void)communicator->Stop();
  318. });
  319. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  320. (void)communicator_with_server_->Stop();
  321. });
  322. }
  323. void Server::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
  324. MS_EXCEPTION_IF_NULL(communicator);
  325. // Register handler for restful requests receviced by scheduler.
  326. communicator->RegisterMsgCallBack("enableFLS",
  327. std::bind(&Server::HandleEnableServerRequest, this, std::placeholders::_1));
  328. communicator->RegisterMsgCallBack("disableFLS",
  329. std::bind(&Server::HandleDisableServerRequest, this, std::placeholders::_1));
  330. communicator->RegisterMsgCallBack("newInstance",
  331. std::bind(&Server::HandleNewInstanceRequest, this, std::placeholders::_1));
  332. communicator->RegisterMsgCallBack("queryInstance",
  333. std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1));
  334. communicator->RegisterMsgCallBack("syncAfterRecover",
  335. std::bind(&Server::HandleSyncAfterRecoveryRequest, this, std::placeholders::_1));
  336. }
  337. void Server::InitExecutor() {
  338. MS_EXCEPTION_IF_NULL(func_graph_);
  339. if (executor_threshold_ == 0) {
  340. MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
  341. return;
  342. }
  343. // The train engine instance is used in both push-type and pull-type kernels,
  344. // so the required_cnt of these kernels must be the same as executor_threshold_.
  345. MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
  346. Executor::GetInstance().Initialize(func_graph_, executor_threshold_);
  347. ModelStore::GetInstance().Initialize();
  348. return;
  349. }
  350. void Server::RegisterRoundKernel() {
  351. MS_EXCEPTION_IF_NULL(iteration_);
  352. auto &rounds = iteration_->rounds();
  353. if (rounds.empty()) {
  354. MS_LOG(EXCEPTION) << "Server has no round registered.";
  355. return;
  356. }
  357. for (auto &round : rounds) {
  358. MS_EXCEPTION_IF_NULL(round);
  359. const std::string &name = round->name();
  360. std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
  361. if (round_kernel == nullptr) {
  362. MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered.";
  363. return;
  364. }
  365. // For some round kernels, the threshold count should be set.
  366. round_kernel->InitKernel(round->threshold_count());
  367. round->BindRoundKernel(round_kernel);
  368. }
  369. return;
  370. }
  371. void Server::InitMetrics() {
  372. if (server_node_->rank_id() == kLeaderServerRank) {
  373. MS_EXCEPTION_IF_NULL(iteration_);
  374. std::shared_ptr<IterationMetrics> iteration_metrics =
  375. std::make_shared<IterationMetrics>(ps::PSContext::instance()->config_file_path());
  376. if (!iteration_metrics->Initialize()) {
  377. MS_LOG(WARNING) << "Initializing metrics failed.";
  378. return;
  379. }
  380. iteration_->set_metrics(iteration_metrics);
  381. }
  382. }
  383. void Server::StartCommunicator() {
  384. if (communicators_with_worker_.empty()) {
  385. MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty.";
  386. return;
  387. }
  388. MS_EXCEPTION_IF_NULL(server_node_);
  389. MS_EXCEPTION_IF_NULL(communicator_with_server_);
  390. MS_LOG(INFO) << "Start communicator with server.";
  391. if (!communicator_with_server_->Start()) {
  392. MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
  393. }
  394. DistributedMetadataStore::GetInstance().Initialize(server_node_);
  395. CollectiveOpsImpl::GetInstance().Initialize(server_node_);
  396. DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
  397. MS_LOG(INFO) << "This server rank is " << server_node_->rank_id();
  398. MS_LOG(INFO) << "Start communicator with worker.";
  399. (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
  400. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  401. MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
  402. if (!communicator->Start()) {
  403. MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
  404. }
  405. });
  406. }
  407. void Server::Recover() {
  408. server_recovery_ = std::make_shared<ServerRecovery>();
  409. MS_EXCEPTION_IF_NULL(server_recovery_);
  410. // Try to recovery from persistent storage.
  411. if (!server_recovery_->Initialize(ps::PSContext::instance()->config_file_path())) {
  412. MS_LOG(WARNING) << "Initializing server recovery failed. Do not recover for this server.";
  413. return;
  414. }
  415. if (server_recovery_->Recover()) {
  416. // If this server recovers, need to notify cluster to reach consistency.
  417. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  418. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  419. MS_LOG(INFO) << "Synchronize with leader server after recovery.";
  420. if (!server_recovery_->SyncAfterRecovery(tcp_comm, server_node_->rank_id())) {
  421. MS_LOG(EXCEPTION) << "Failed to reach consistency of the cluster after recovery.";
  422. return;
  423. }
  424. }
  425. // Set the recovery handler to Iteration.
  426. MS_EXCEPTION_IF_NULL(iteration_);
  427. iteration_->set_recovery_handler(server_recovery_);
  428. }
  429. void Server::ProcessBeforeScalingOut() {
  430. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  431. iteration_->ScalingBarrier();
  432. safemode_ = true;
  433. }
  434. void Server::ProcessBeforeScalingIn() {
  435. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  436. iteration_->ScalingBarrier();
  437. safemode_ = true;
  438. }
  439. void Server::ProcessAfterScalingOut() {
  440. std::unique_lock<std::mutex> lock(scaling_mtx_);
  441. MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  442. if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
  443. MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
  444. }
  445. if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
  446. MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
  447. }
  448. if (!DistributedCountService::GetInstance().ReInitForScaling()) {
  449. MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
  450. }
  451. if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
  452. MS_LOG(WARNING) << "Iteration reinitializing failed.";
  453. }
  454. if (!Executor::GetInstance().ReInitForScaling()) {
  455. MS_LOG(WARNING) << "Executor reinitializing failed.";
  456. }
  457. std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
  458. safemode_ = false;
  459. }
  460. void Server::ProcessAfterScalingIn() {
  461. std::unique_lock<std::mutex> lock(scaling_mtx_);
  462. MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  463. if (server_node_->rank_id() == UINT32_MAX) {
  464. MS_LOG(WARNING) << "This server the one to be scaled in. Server need to wait SIGTERM to exit.";
  465. return;
  466. }
  467. // If the server is not the one to be scaled in, reintialize modules and recover service.
  468. if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
  469. MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
  470. }
  471. if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
  472. MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
  473. }
  474. if (!DistributedCountService::GetInstance().ReInitForScaling()) {
  475. MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
  476. }
  477. if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
  478. MS_LOG(WARNING) << "Iteration reinitializing failed.";
  479. }
  480. if (!Executor::GetInstance().ReInitForScaling()) {
  481. MS_LOG(WARNING) << "Executor reinitializing failed.";
  482. }
  483. std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
  484. safemode_ = false;
  485. }
  486. void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  487. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  488. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  489. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  490. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  491. std::string result_message = "";
  492. bool result = iteration_->EnableServerInstance(&result_message);
  493. nlohmann::json response;
  494. response["result"] = result;
  495. response["message"] = result_message;
  496. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  497. MS_LOG(ERROR) << "Sending response failed.";
  498. return;
  499. }
  500. }
  501. void Server::HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  502. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  503. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  504. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  505. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  506. std::string result_message = "";
  507. bool result = iteration_->DisableServerInstance(&result_message);
  508. nlohmann::json response;
  509. response["result"] = result;
  510. response["message"] = result_message;
  511. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  512. MS_LOG(ERROR) << "Sending response failed.";
  513. return;
  514. }
  515. }
  516. void Server::HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  517. MS_ERROR_IF_NULL_WO_RET_VAL(message);
  518. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  519. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  520. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  521. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  522. std::string hyper_params_str(static_cast<const char *>(message->data()), message->len());
  523. nlohmann::json new_instance_json;
  524. nlohmann::json response;
  525. try {
  526. new_instance_json = nlohmann::json::parse(hyper_params_str);
  527. } catch (const std::exception &e) {
  528. response["result"] = false;
  529. response["message"] = "The hyper-parameter data is not in json format.";
  530. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  531. MS_LOG(ERROR) << "Sending response failed.";
  532. return;
  533. }
  534. }
  535. std::string result_message = "";
  536. bool result = iteration_->NewInstance(new_instance_json, &result_message);
  537. response["result"] = result;
  538. response["message"] = result_message;
  539. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  540. MS_LOG(ERROR) << "Sending response failed.";
  541. return;
  542. }
  543. }
  544. void Server::HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  545. MS_ERROR_IF_NULL_WO_RET_VAL(message);
  546. nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> response;
  547. response["start_fl_job_threshold"] = ps::PSContext::instance()->start_fl_job_threshold();
  548. response["start_fl_job_time_window"] = ps::PSContext::instance()->start_fl_job_time_window();
  549. response["update_model_ratio"] = ps::PSContext::instance()->update_model_ratio();
  550. response["update_model_time_window"] = ps::PSContext::instance()->update_model_time_window();
  551. response["fl_iteration_num"] = ps::PSContext::instance()->fl_iteration_num();
  552. response["client_epoch_num"] = ps::PSContext::instance()->client_epoch_num();
  553. response["client_batch_size"] = ps::PSContext::instance()->client_batch_size();
  554. response["client_learning_rate"] = ps::PSContext::instance()->client_learning_rate();
  555. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  556. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  557. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  558. MS_LOG(ERROR) << "Sending response failed.";
  559. return;
  560. }
  561. }
  562. void Server::HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  563. MS_ERROR_IF_NULL_WO_RET_VAL(message);
  564. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  565. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  566. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  567. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  568. MS_LOG(INFO) << "Receive SyncAfterRecover request from other server.";
  569. std::string response = "success";
  570. if (!tcp_comm->SendResponse(response.c_str(), response.size(), message)) {
  571. MS_LOG(ERROR) << "Sending response of SyncAfterRecoverRequest failed.";
  572. return;
  573. }
  574. if (!safemode_.load()) {
  575. MS_LOG(INFO) << "Need to synchronize for other server's recovery";
  576. SyncAfterRecover sync_after_recovery_req;
  577. (void)sync_after_recovery_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  578. if (!iteration_->SyncAfterRecovery(sync_after_recovery_req.current_iter_num())) {
  579. MS_LOG(ERROR) << "Sync after recovery failed.";
  580. return;
  581. }
  582. }
  583. }
  584. } // namespace server
  585. } // namespace fl
  586. } // namespace mindspore