You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

server.cc 27 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "fl/server/server.h"
  17. #include <memory>
  18. #include <string>
  19. #include <csignal>
  20. #ifdef ENABLE_ARMOUR
  21. #include "fl/armour/secure_protocol/secret_sharing.h"
  22. #endif
  23. #include "fl/server/round.h"
  24. #include "fl/server/model_store.h"
  25. #include "fl/server/iteration.h"
  26. #include "fl/server/collective_ops_impl.h"
  27. #include "fl/server/distributed_metadata_store.h"
  28. #include "fl/server/distributed_count_service.h"
  29. #include "fl/server/kernel/round/round_kernel_factory.h"
  30. namespace mindspore {
  31. namespace fl {
  32. namespace server {
  33. // The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster managers like K8S.
  34. std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr;
  35. std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {};
  36. void SignalHandler(int signal) {
  37. MS_LOG(WARNING) << "SIGTERM captured: " << signal;
  38. (void)std::for_each(g_communicators_with_worker.begin(), g_communicators_with_worker.end(),
  39. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  40. MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
  41. (void)communicator->Stop();
  42. });
  43. MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server);
  44. (void)g_communicator_with_server->Stop();
  45. return;
  46. }
  47. void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
  48. const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
  49. MS_EXCEPTION_IF_NULL(func_graph);
  50. func_graph_ = func_graph;
  51. if (rounds_config.empty()) {
  52. MS_LOG(EXCEPTION) << "Rounds are empty.";
  53. return;
  54. }
  55. rounds_config_ = rounds_config;
  56. cipher_config_ = cipher_config;
  57. use_tcp_ = use_tcp;
  58. use_http_ = use_http;
  59. http_port_ = http_port;
  60. executor_threshold_ = executor_threshold;
  61. signal(SIGTERM, SignalHandler);
  62. return;
  63. }
  64. // Each step of the server pipeline may have dependency on other steps, which includes:
  65. // InitServerContext must be the first step to set contexts for later steps.
  66. // Server Running relies on URL or Message Type Register:
  67. // StartCommunicator---->InitIteration
  68. // Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion:
  69. // RegisterRoundKernel---->StartCommunicator
  70. // Kernel Initialization relies on Executor Initialization:
  71. // RegisterRoundKernel---->InitExecutor
  72. // Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
  73. // InitCipher---->InitExecutor
  74. void Server::Run() {
  75. std::unique_lock<std::mutex> lock(scaling_mtx_);
  76. InitServerContext();
  77. InitCluster();
  78. InitIteration();
  79. RegisterCommCallbacks();
  80. StartCommunicator();
  81. InitExecutor();
  82. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  83. if (encrypt_type != ps::kNotEncryptType) {
  84. InitCipher();
  85. MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
  86. }
  87. RegisterRoundKernel();
  88. InitMetrics();
  89. MS_LOG(INFO) << "Server started successfully.";
  90. safemode_ = false;
  91. lock.unlock();
  92. // Wait communicators to stop so the main thread is blocked.
  93. (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
  94. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  95. MS_EXCEPTION_IF_NULL(communicator);
  96. communicator->Join();
  97. });
  98. MS_EXCEPTION_IF_NULL(communicator_with_server_);
  99. communicator_with_server_->Join();
  100. MsException::Instance().CheckException();
  101. return;
  102. }
  103. void Server::SwitchToSafeMode() {
  104. MS_LOG(INFO) << "Server switch to safemode.";
  105. safemode_ = true;
  106. }
  107. void Server::CancelSafeMode() {
  108. MS_LOG(INFO) << "Server cancel safemode.";
  109. safemode_ = false;
  110. }
  111. bool Server::IsSafeMode() const { return safemode_.load(); }
  112. void Server::WaitExitSafeMode() const {
  113. while (safemode_.load()) {
  114. std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
  115. }
  116. }
  117. void Server::InitServerContext() {
  118. ps::PSContext::instance()->GenerateResetterRound();
  119. scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
  120. scheduler_port_ = ps::PSContext::instance()->scheduler_port();
  121. worker_num_ = ps::PSContext::instance()->initial_worker_num();
  122. server_num_ = ps::PSContext::instance()->initial_server_num();
  123. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  124. if (encrypt_type == ps::kPWEncryptType && server_num_ > 1) {
  125. MS_LOG(EXCEPTION) << "Only single server is supported for PW_ENCRYPT now, but got server_num is:." << server_num_;
  126. return;
  127. }
  128. return;
  129. }
  130. void Server::InitCluster() {
  131. server_node_ = std::make_shared<ps::core::ServerNode>();
  132. MS_EXCEPTION_IF_NULL(server_node_);
  133. task_executor_ = std::make_shared<ps::core::TaskExecutor>(kExecutorThreadPoolSize);
  134. MS_EXCEPTION_IF_NULL(task_executor_);
  135. if (!InitCommunicatorWithServer()) {
  136. MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
  137. return;
  138. }
  139. if (!InitCommunicatorWithWorker()) {
  140. MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
  141. return;
  142. }
  143. return;
  144. }
  145. bool Server::InitCommunicatorWithServer() {
  146. MS_EXCEPTION_IF_NULL(task_executor_);
  147. MS_EXCEPTION_IF_NULL(server_node_);
  148. communicator_with_server_ =
  149. server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_);
  150. MS_EXCEPTION_IF_NULL(communicator_with_server_);
  151. g_communicator_with_server = communicator_with_server_;
  152. return true;
  153. }
  154. bool Server::InitCommunicatorWithWorker() {
  155. MS_EXCEPTION_IF_NULL(server_node_);
  156. MS_EXCEPTION_IF_NULL(task_executor_);
  157. if (!use_tcp_ && !use_http_) {
  158. MS_LOG(EXCEPTION) << "At least one type of protocol should be set.";
  159. return false;
  160. }
  161. if (use_tcp_) {
  162. MS_EXCEPTION_IF_NULL(communicator_with_server_);
  163. auto tcp_comm = communicator_with_server_;
  164. MS_EXCEPTION_IF_NULL(tcp_comm);
  165. communicators_with_worker_.push_back(tcp_comm);
  166. }
  167. if (use_http_) {
  168. auto http_comm = server_node_->GetOrCreateHttpComm(server_node_->BoundIp(), http_port_, task_executor_);
  169. MS_EXCEPTION_IF_NULL(http_comm);
  170. communicators_with_worker_.push_back(http_comm);
  171. }
  172. g_communicators_with_worker = communicators_with_worker_;
  173. return true;
  174. }
  175. void Server::InitIteration() {
  176. iteration_ = &Iteration::GetInstance();
  177. MS_EXCEPTION_IF_NULL(iteration_);
  178. // 1.Add rounds to the iteration according to the server mode.
  179. for (const RoundConfig &config : rounds_config_) {
  180. std::shared_ptr<Round> round =
  181. std::make_shared<Round>(config.name, config.check_timeout, config.time_window, config.check_count,
  182. config.threshold_count, config.server_num_as_threshold);
  183. MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout
  184. << ", time window: " << config.time_window << ", check_count: " << config.check_count
  185. << ", threshold: " << config.threshold_count
  186. << ", server_num_as_threshold: " << config.server_num_as_threshold;
  187. iteration_->AddRound(round);
  188. }
  189. #ifdef ENABLE_ARMOUR
  190. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  191. if (encrypt_type == ps::kPWEncryptType) {
  192. cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
  193. cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
  194. cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
  195. cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
  196. cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
  197. cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold;
  198. cipher_time_window_ = cipher_config_.cipher_time_window;
  199. MS_LOG(INFO) << "Initializing cipher:";
  200. MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
  201. << " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
  202. << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
  203. MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
  204. << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
  205. << " cipher_time_window_: " << cipher_time_window_
  206. << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
  207. std::shared_ptr<Round> exchange_keys_round =
  208. std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
  209. MS_EXCEPTION_IF_NULL(exchange_keys_round);
  210. iteration_->AddRound(exchange_keys_round);
  211. std::shared_ptr<Round> get_keys_round =
  212. std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
  213. MS_EXCEPTION_IF_NULL(get_keys_round);
  214. iteration_->AddRound(get_keys_round);
  215. std::shared_ptr<Round> share_secrets_round =
  216. std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
  217. MS_EXCEPTION_IF_NULL(share_secrets_round);
  218. iteration_->AddRound(share_secrets_round);
  219. std::shared_ptr<Round> get_secrets_round =
  220. std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
  221. MS_EXCEPTION_IF_NULL(get_secrets_round);
  222. iteration_->AddRound(get_secrets_round);
  223. std::shared_ptr<Round> get_clientlist_round =
  224. std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
  225. MS_EXCEPTION_IF_NULL(get_clientlist_round);
  226. iteration_->AddRound(get_clientlist_round);
  227. std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
  228. "reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
  229. MS_EXCEPTION_IF_NULL(reconstruct_secrets_round);
  230. iteration_->AddRound(reconstruct_secrets_round);
  231. MS_LOG(INFO) << "Cipher rounds has been added.";
  232. }
  233. #endif
  234. // 2.Initialize all the rounds.
  235. TimeOutCb time_out_cb = std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
  236. FinishIterCb finish_iter_cb =
  237. std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
  238. iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
  239. return;
  240. }
  241. void Server::InitCipher() {
  242. #ifdef ENABLE_ARMOUR
  243. cipher_init_ = &armour::CipherInit::GetInstance();
  244. int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
  245. unsigned char cipher_p[SECRET_MAX_LEN] = {0};
  246. const int cipher_g = 1;
  247. unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
  248. float dp_eps = ps::PSContext::instance()->dp_eps();
  249. float dp_delta = ps::PSContext::instance()->dp_delta();
  250. float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
  251. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  252. mpz_t prim;
  253. mpz_init(prim);
  254. mindspore::armour::GetRandomPrime(prim);
  255. mindspore::armour::PrintBigInteger(prim, 16);
  256. size_t len_cipher_prime;
  257. mpz_export((unsigned char *)cipher_prime, &len_cipher_prime, sizeof(unsigned char), 1, 0, 0, prim);
  258. mindspore::armour::CipherPublicPara param;
  259. param.g = cipher_g;
  260. param.t = cipher_t;
  261. int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
  262. if (ret != 0) {
  263. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  264. return;
  265. }
  266. ret = memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
  267. if (ret != 0) {
  268. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  269. return;
  270. }
  271. param.dp_delta = dp_delta;
  272. param.dp_eps = dp_eps;
  273. param.dp_norm_clip = dp_norm_clip;
  274. param.encrypt_type = encrypt_type;
  275. cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_,
  276. cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
  277. cipher_reconstruct_secrets_up_cnt_);
  278. #endif
  279. }
  280. void Server::RegisterCommCallbacks() {
  281. // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
  282. // rounds' callbacks.
  283. MS_EXCEPTION_IF_NULL(server_node_);
  284. MS_EXCEPTION_IF_NULL(iteration_);
  285. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  286. MS_EXCEPTION_IF_NULL(tcp_comm);
  287. // Set message callbacks for server-to-server communication.
  288. DistributedMetadataStore::GetInstance().RegisterMessageCallback(tcp_comm);
  289. DistributedCountService::GetInstance().RegisterMessageCallback(tcp_comm);
  290. iteration_->RegisterMessageCallback(tcp_comm);
  291. iteration_->RegisterEventCallback(server_node_);
  292. // Set exception event callbacks for server.
  293. RegisterExceptionEventCallback(tcp_comm);
  294. // Set message callbacks for server.
  295. RegisterMessageCallback(tcp_comm);
  296. if (!server_node_->InitFollowerScaler()) {
  297. MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
  298. return;
  299. }
  300. // Set scaling barriers before scaling.
  301. server_node_->RegisterFollowerScalerBarrierBeforeScaleOut("ServerPipeline",
  302. std::bind(&Server::ProcessBeforeScalingOut, this));
  303. server_node_->RegisterFollowerScalerBarrierBeforeScaleIn("ServerPipeline",
  304. std::bind(&Server::ProcessBeforeScalingIn, this));
  305. // Set handlers after scheduler scaling operations are done.
  306. server_node_->RegisterFollowerScalerHandlerAfterScaleOut("ServerPipeline",
  307. std::bind(&Server::ProcessAfterScalingOut, this));
  308. server_node_->RegisterFollowerScalerHandlerAfterScaleIn("ServerPipeline",
  309. std::bind(&Server::ProcessAfterScalingIn, this));
  310. }
  311. void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
  312. MS_EXCEPTION_IF_NULL(communicator);
  313. communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
  314. MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
  315. safemode_ = true;
  316. (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
  317. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  318. MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
  319. (void)communicator->Stop();
  320. });
  321. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  322. (void)communicator_with_server_->Stop();
  323. });
  324. communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
  325. MS_LOG(ERROR)
  326. << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
  327. "network building phase.";
  328. safemode_ = true;
  329. (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
  330. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  331. MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
  332. (void)communicator->Stop();
  333. });
  334. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  335. (void)communicator_with_server_->Stop();
  336. });
  337. }
  338. void Server::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
  339. MS_EXCEPTION_IF_NULL(communicator);
  340. // Register handler for restful requests receviced by scheduler.
  341. communicator->RegisterMsgCallBack("enableFLS",
  342. std::bind(&Server::HandleEnableServerRequest, this, std::placeholders::_1));
  343. communicator->RegisterMsgCallBack("disableFLS",
  344. std::bind(&Server::HandleDisableServerRequest, this, std::placeholders::_1));
  345. communicator->RegisterMsgCallBack("newInstance",
  346. std::bind(&Server::HandleNewInstanceRequest, this, std::placeholders::_1));
  347. communicator->RegisterMsgCallBack("queryInstance",
  348. std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1));
  349. }
  350. void Server::InitExecutor() {
  351. MS_EXCEPTION_IF_NULL(func_graph_);
  352. if (executor_threshold_ == 0) {
  353. MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
  354. return;
  355. }
  356. // The train engine instance is used in both push-type and pull-type kernels,
  357. // so the required_cnt of these kernels must be the same as executor_threshold_.
  358. MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
  359. Executor::GetInstance().Initialize(func_graph_, executor_threshold_);
  360. ModelStore::GetInstance().Initialize();
  361. return;
  362. }
  363. void Server::RegisterRoundKernel() {
  364. MS_EXCEPTION_IF_NULL(iteration_);
  365. auto &rounds = iteration_->rounds();
  366. if (rounds.empty()) {
  367. MS_LOG(EXCEPTION) << "Server has no round registered.";
  368. return;
  369. }
  370. for (auto &round : rounds) {
  371. MS_EXCEPTION_IF_NULL(round);
  372. const std::string &name = round->name();
  373. std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
  374. if (round_kernel == nullptr) {
  375. MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered.";
  376. return;
  377. }
  378. // For some round kernels, the threshold count should be set.
  379. round_kernel->InitKernel(round->threshold_count());
  380. round->BindRoundKernel(round_kernel);
  381. }
  382. return;
  383. }
  384. void Server::InitMetrics() {
  385. if (server_node_->rank_id() == kLeaderServerRank) {
  386. MS_EXCEPTION_IF_NULL(iteration_);
  387. std::shared_ptr<IterationMetrics> iteration_metrics =
  388. std::make_shared<IterationMetrics>(ps::PSContext::instance()->config_file_path());
  389. if (!iteration_metrics->Initialize()) {
  390. MS_LOG(WARNING) << "Initializing metrics failed.";
  391. return;
  392. }
  393. iteration_->set_metrics(iteration_metrics);
  394. }
  395. }
  396. void Server::StartCommunicator() {
  397. if (communicators_with_worker_.empty()) {
  398. MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty.";
  399. return;
  400. }
  401. MS_EXCEPTION_IF_NULL(server_node_);
  402. MS_EXCEPTION_IF_NULL(communicator_with_server_);
  403. MS_LOG(INFO) << "Start communicator with server.";
  404. if (!communicator_with_server_->Start()) {
  405. MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
  406. return;
  407. }
  408. DistributedMetadataStore::GetInstance().Initialize(server_node_);
  409. CollectiveOpsImpl::GetInstance().Initialize(server_node_);
  410. DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
  411. MS_LOG(INFO) << "This server rank is " << server_node_->rank_id();
  412. MS_LOG(INFO) << "Start communicator with worker.";
  413. (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
  414. [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
  415. MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
  416. if (!communicator->Start()) {
  417. MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
  418. }
  419. });
  420. }
  421. void Server::ProcessBeforeScalingOut() {
  422. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  423. iteration_->ScalingBarrier();
  424. safemode_ = true;
  425. }
  426. void Server::ProcessBeforeScalingIn() {
  427. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  428. iteration_->ScalingBarrier();
  429. safemode_ = true;
  430. }
  431. void Server::ProcessAfterScalingOut() {
  432. std::unique_lock<std::mutex> lock(scaling_mtx_);
  433. MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  434. if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
  435. MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
  436. }
  437. if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
  438. MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
  439. }
  440. if (!DistributedCountService::GetInstance().ReInitForScaling()) {
  441. MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
  442. }
  443. if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
  444. MS_LOG(WARNING) << "Iteration reinitializing failed.";
  445. }
  446. if (!Executor::GetInstance().ReInitForScaling()) {
  447. MS_LOG(WARNING) << "Executor reinitializing failed.";
  448. }
  449. std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
  450. safemode_ = false;
  451. }
  452. void Server::ProcessAfterScalingIn() {
  453. std::unique_lock<std::mutex> lock(scaling_mtx_);
  454. MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  455. if (server_node_->rank_id() == UINT32_MAX) {
  456. MS_LOG(WARNING) << "This server the one to be scaled in. Server need to wait SIGTERM to exit.";
  457. return;
  458. }
  459. // If the server is not the one to be scaled in, reintialize modules and recover service.
  460. if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
  461. MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
  462. }
  463. if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
  464. MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
  465. }
  466. if (!DistributedCountService::GetInstance().ReInitForScaling()) {
  467. MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
  468. }
  469. if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
  470. MS_LOG(WARNING) << "Iteration reinitializing failed.";
  471. }
  472. if (!Executor::GetInstance().ReInitForScaling()) {
  473. MS_LOG(WARNING) << "Executor reinitializing failed.";
  474. }
  475. std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
  476. safemode_ = false;
  477. }
  478. void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  479. MS_ERROR_IF_NULL_WO_RET_VAL(message);
  480. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  481. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  482. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  483. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  484. std::string result_message = "";
  485. bool result = iteration_->EnableServerInstance(&result_message);
  486. nlohmann::json response;
  487. response["result"] = result;
  488. response["message"] = result_message;
  489. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  490. MS_LOG(ERROR) << "Sending response failed.";
  491. return;
  492. }
  493. }
  494. void Server::HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  495. MS_ERROR_IF_NULL_WO_RET_VAL(message);
  496. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  497. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  498. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  499. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  500. std::string result_message = "";
  501. bool result = iteration_->DisableServerInstance(&result_message);
  502. nlohmann::json response;
  503. response["result"] = result;
  504. response["message"] = result_message;
  505. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  506. MS_LOG(ERROR) << "Sending response failed.";
  507. return;
  508. }
  509. }
  510. void Server::HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  511. MS_ERROR_IF_NULL_WO_RET_VAL(message);
  512. MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
  513. MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
  514. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  515. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  516. std::string hyper_params_str(static_cast<const char *>(message->data()), message->len());
  517. nlohmann::json new_instance_json;
  518. nlohmann::json response;
  519. try {
  520. new_instance_json = nlohmann::json::parse(hyper_params_str);
  521. } catch (const std::exception &e) {
  522. response["result"] = false;
  523. response["message"] = "The hyper-parameter data is not in json format.";
  524. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  525. MS_LOG(ERROR) << "Sending response failed.";
  526. return;
  527. }
  528. }
  529. std::string result_message = "";
  530. bool result = iteration_->NewInstance(new_instance_json, &result_message);
  531. response["result"] = result;
  532. response["message"] = result_message;
  533. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  534. MS_LOG(ERROR) << "Sending response failed.";
  535. return;
  536. }
  537. }
  538. void Server::HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  539. MS_ERROR_IF_NULL_WO_RET_VAL(message);
  540. nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> response;
  541. response["start_fl_job_threshold"] = ps::PSContext::instance()->start_fl_job_threshold();
  542. response["start_fl_job_time_window"] = ps::PSContext::instance()->start_fl_job_time_window();
  543. response["update_model_ratio"] = ps::PSContext::instance()->update_model_ratio();
  544. response["update_model_time_window"] = ps::PSContext::instance()->update_model_time_window();
  545. response["fl_iteration_num"] = ps::PSContext::instance()->fl_iteration_num();
  546. response["client_epoch_num"] = ps::PSContext::instance()->client_epoch_num();
  547. response["client_batch_size"] = ps::PSContext::instance()->client_batch_size();
  548. response["client_learning_rate"] = ps::PSContext::instance()->client_learning_rate();
  549. auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  550. MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  551. if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
  552. MS_LOG(ERROR) << "Sending response failed.";
  553. return;
  554. }
  555. }
  556. } // namespace server
  557. } // namespace fl
  558. } // namespace mindspore