/** * Copyright 2021-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "fl/server/server.h" #include #include #include #ifdef ENABLE_ARMOUR #include "fl/armour/secure_protocol/secret_sharing.h" #endif #include "fl/server/round.h" #include "fl/server/model_store.h" #include "fl/server/iteration.h" #include "fl/server/collective_ops_impl.h" #include "fl/server/distributed_metadata_store.h" #include "fl/server/distributed_count_service.h" #include "fl/server/kernel/round/round_kernel_factory.h" namespace mindspore { namespace fl { namespace server { // The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster manager like K8S. std::shared_ptr g_communicator_with_server = nullptr; std::vector> g_communicators_with_worker = {}; void SignalHandler(int signal) { MS_LOG(WARNING) << "SIGTERM captured: " << signal; (void)std::for_each(g_communicators_with_worker.begin(), g_communicators_with_worker.end(), [](const std::shared_ptr &communicator) { MS_ERROR_IF_NULL_WO_RET_VAL(communicator); (void)communicator->Stop(); }); MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server); (void)g_communicator_with_server->Stop(); } void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector &rounds_config, const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) { MS_EXCEPTION_IF_NULL(func_graph); func_graph_ = func_graph; if (rounds_config.empty()) { MS_LOG(EXCEPTION) << "Rounds are empty."; return; } rounds_config_ = rounds_config; cipher_config_ = cipher_config; use_tcp_ = use_tcp; use_http_ = use_http; http_port_ = http_port; executor_threshold_ = executor_threshold; (void)signal(SIGTERM, SignalHandler); return; } void Server::Run() { std::unique_lock lock(scaling_mtx_); InitServerContext(); InitPkiCertificate(); InitCluster(); InitIteration(); RegisterCommCallbacks(); StartCommunicator(); InitExecutor(); std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); if (encrypt_type != ps::kNotEncryptType) { InitCipher(); MS_LOG(INFO) << "Parameters for secure aggregation have been initiated."; } RegisterRoundKernel(); InitMetrics(); Recover(); MS_LOG(INFO) << "Server started successfully."; safemode_ = false; lock.unlock(); // Wait communicators to stop so the main thread is blocked. (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), [](const std::shared_ptr &communicator) { MS_EXCEPTION_IF_NULL(communicator); communicator->Join(); }); MS_EXCEPTION_IF_NULL(communicator_with_server_); communicator_with_server_->Join(); MsException::Instance().CheckException(); } void Server::InitPkiCertificate() { if (ps::PSContext::instance()->pki_verify()) { root_first_ca_path_ = ps::PSContext::instance()->root_first_ca_path(); root_second_ca_path_ = ps::PSContext::instance()->root_second_ca_path(); equip_crl_path_ = ps::PSContext::instance()->equip_crl_path(); replay_attack_time_diff_ = ps::PSContext::instance()->replay_attack_time_diff(); bool ret = mindspore::ps::server::CertVerify::initRootCertAndCRL(root_first_ca_path_, root_second_ca_path_, equip_crl_path_, replay_attack_time_diff_); if (!ret) { MS_LOG(EXCEPTION) << "init root cert and crl failed."; return; } return; } } void Server::SwitchToSafeMode() { MS_LOG(INFO) << "Server switch to safemode."; safemode_ = true; } void Server::CancelSafeMode() { MS_LOG(INFO) << "Server cancel safemode."; safemode_ = false; } bool Server::IsSafeMode() const { return safemode_.load(); } void Server::WaitExitSafeMode() const { while (safemode_.load()) { std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime)); } } void Server::InitServerContext() { ps::PSContext::instance()->GenerateResetterRound(); scheduler_ip_ = ps::PSContext::instance()->scheduler_host(); scheduler_port_ = ps::PSContext::instance()->scheduler_port(); worker_num_ = ps::PSContext::instance()->initial_worker_num(); server_num_ = ps::PSContext::instance()->initial_server_num(); return; } void Server::InitCluster() { server_node_ = std::make_shared(); MS_EXCEPTION_IF_NULL(server_node_); task_executor_ = std::make_shared(kExecutorThreadPoolSize); MS_EXCEPTION_IF_NULL(task_executor_); if (!InitCommunicatorWithServer()) { MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed."; return; } if (!InitCommunicatorWithWorker()) { MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed."; return; } return; } bool Server::SubmitTask(std::function &&task) { if (task_executor_ == nullptr) { return false; } return task_executor_->Submit(task); } bool Server::InitCommunicatorWithServer() { MS_EXCEPTION_IF_NULL(task_executor_); MS_EXCEPTION_IF_NULL(server_node_); communicator_with_server_ = server_node_->GetOrCreateTcpComm(scheduler_ip_, static_cast(scheduler_port_), worker_num_, server_num_, task_executor_); MS_EXCEPTION_IF_NULL(communicator_with_server_); g_communicator_with_server = communicator_with_server_; return true; } bool Server::InitCommunicatorWithWorker() { MS_EXCEPTION_IF_NULL(server_node_); MS_EXCEPTION_IF_NULL(task_executor_); if (!use_tcp_ && !use_http_) { MS_LOG(EXCEPTION) << "At least one type of protocol should be set."; return false; } if (use_tcp_) { MS_EXCEPTION_IF_NULL(communicator_with_server_); auto tcp_comm = communicator_with_server_; MS_EXCEPTION_IF_NULL(tcp_comm); communicators_with_worker_.push_back(tcp_comm); } if (use_http_) { auto http_comm = server_node_->GetOrCreateHttpComm(server_node_->BoundIp(), http_port_, task_executor_); MS_EXCEPTION_IF_NULL(http_comm); communicators_with_worker_.push_back(http_comm); } g_communicators_with_worker = communicators_with_worker_; return true; } void Server::InitIteration() { iteration_ = &Iteration::GetInstance(); MS_EXCEPTION_IF_NULL(iteration_); // 1.Add rounds to the iteration according to the server mode. for (const RoundConfig &config : rounds_config_) { std::shared_ptr round = std::make_shared(config.name, config.check_timeout, config.time_window, config.check_count, config.threshold_count, config.server_num_as_threshold); MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout << ", time window: " << config.time_window << ", check_count: " << config.check_count << ", threshold: " << config.threshold_count << ", server_num_as_threshold: " << config.server_num_as_threshold; iteration_->AddRound(round); } #ifdef ENABLE_ARMOUR std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); if (encrypt_type == ps::kPWEncryptType) { cipher_exchange_keys_cnt_ = cipher_config_.exchange_keys_threshold; cipher_get_keys_cnt_ = cipher_config_.get_keys_threshold; cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold; cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold; cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold; cipher_push_list_sign_cnt_ = cipher_config_.push_list_sign_threshold; cipher_get_list_sign_cnt_ = cipher_config_.get_list_sign_threshold; minimum_clients_for_reconstruct = cipher_config_.minimum_clients_for_reconstruct; minimum_secret_shares_for_reconstruct = cipher_config_.minimum_clients_for_reconstruct - 1; cipher_time_window_ = cipher_config_.cipher_time_window; MS_LOG(INFO) << "Initializing cipher:"; MS_LOG(INFO) << " cipher_exchange_keys_cnt_: " << cipher_exchange_keys_cnt_ << " cipher_get_keys_cnt_: " << cipher_get_keys_cnt_ << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_; MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_ << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_ << " cipher_push_list_sign_cnt_: " << cipher_push_list_sign_cnt_ << " cipher_get_list_sign_cnt_: " << cipher_get_list_sign_cnt_ << " minimum_clients_for_reconstruct: " << minimum_clients_for_reconstruct << " minimum_secret_shares_for_reconstruct: " << minimum_secret_shares_for_reconstruct << " cipher_time_window_: " << cipher_time_window_; } #endif // 2.Initialize all the rounds. TimeOutCb time_out_cb = std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2); FinishIterCb finish_iter_cb = std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2); iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb); iteration_->InitGlobalIterTimer(time_out_cb); return; } void Server::InitCipher() { #ifdef ENABLE_ARMOUR cipher_init_ = &armour::CipherInit::GetInstance(); int cipher_t = SizeToInt(minimum_secret_shares_for_reconstruct); unsigned char cipher_p[SECRET_MAX_LEN] = {0}; const int cipher_g = 1; float dp_eps = ps::PSContext::instance()->dp_eps(); float dp_delta = ps::PSContext::instance()->dp_delta(); float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip(); std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); float sign_k = ps::PSContext::instance()->sign_k(); float sign_eps = ps::PSContext::instance()->sign_eps(); float sign_thr_ratio = ps::PSContext::instance()->sign_thr_ratio(); float sign_global_lr = ps::PSContext::instance()->sign_global_lr(); int sign_dim_out = ps::PSContext::instance()->sign_dim_out(); mindspore::armour::CipherPublicPara param; param.g = cipher_g; param.t = cipher_t; int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p)); if (ret != 0) { MS_LOG(EXCEPTION) << "Memcpy_s error, errorno" << ret; } param.dp_delta = dp_delta; param.dp_eps = dp_eps; param.dp_norm_clip = dp_norm_clip; param.encrypt_type = encrypt_type; param.sign_k = sign_k; param.sign_eps = sign_eps; param.sign_thr_ratio = sign_thr_ratio; param.sign_global_lr = sign_global_lr; param.sign_dim_out = sign_dim_out; BIGNUM *prim = BN_new(); if (prim == NULL) { MS_LOG(EXCEPTION) << "new bn failed."; ret = -1; } else { ret = mindspore::armour::GetPrime(prim); } if (ret == 0) { (void)BN_bn2bin(prim, reinterpret_cast(param.prime)); } else { MS_LOG(EXCEPTION) << "Get prime failed."; } if (prim != NULL) { BN_clear_free(prim); } if (!cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_, cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_push_list_sign_cnt_, cipher_get_list_sign_cnt_, minimum_clients_for_reconstruct)) { MS_LOG(EXCEPTION) << "cipher init fail."; } #endif } void Server::RegisterCommCallbacks() { // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register // rounds' callbacks. MS_EXCEPTION_IF_NULL(server_node_); MS_EXCEPTION_IF_NULL(iteration_); auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); MS_EXCEPTION_IF_NULL(tcp_comm); // Set message callbacks for server-to-server communication. DistributedMetadataStore::GetInstance().RegisterMessageCallback(tcp_comm); DistributedCountService::GetInstance().RegisterMessageCallback(tcp_comm); iteration_->RegisterMessageCallback(tcp_comm); iteration_->RegisterEventCallback(server_node_); // Set exception event callbacks for server. RegisterExceptionEventCallback(tcp_comm); // Set message callbacks for server. RegisterMessageCallback(tcp_comm); if (!server_node_->InitFollowerScaler()) { MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed."; return; } // Set scaling barriers before scaling. server_node_->RegisterFollowerScalerBarrierBeforeScaleOut("ServerPipeline", std::bind(&Server::ProcessBeforeScalingOut, this)); server_node_->RegisterFollowerScalerBarrierBeforeScaleIn("ServerPipeline", std::bind(&Server::ProcessBeforeScalingIn, this)); // Set handlers after scheduler scaling operations are done. server_node_->RegisterFollowerScalerHandlerAfterScaleOut("ServerPipeline", std::bind(&Server::ProcessAfterScalingOut, this)); server_node_->RegisterFollowerScalerHandlerAfterScaleIn("ServerPipeline", std::bind(&Server::ProcessAfterScalingIn, this)); } void Server::RegisterExceptionEventCallback(const std::shared_ptr &communicator) { MS_EXCEPTION_IF_NULL(communicator); communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() { MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed."; safemode_ = true; (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), [](const std::shared_ptr &communicator) { MS_ERROR_IF_NULL_WO_RET_VAL(communicator); (void)communicator->Stop(); }); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); (void)communicator_with_server_->Stop(); }); communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() { MS_LOG(ERROR) << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the " "network building phase."; safemode_ = true; (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), [](const std::shared_ptr &communicator) { MS_ERROR_IF_NULL_WO_RET_VAL(communicator); (void)communicator->Stop(); }); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); (void)communicator_with_server_->Stop(); }); } void Server::RegisterMessageCallback(const std::shared_ptr &communicator) { MS_EXCEPTION_IF_NULL(communicator); // Register handler for restful requests receviced by scheduler. communicator->RegisterMsgCallBack("enableFLS", std::bind(&Server::HandleEnableServerRequest, this, std::placeholders::_1)); communicator->RegisterMsgCallBack("disableFLS", std::bind(&Server::HandleDisableServerRequest, this, std::placeholders::_1)); communicator->RegisterMsgCallBack("newInstance", std::bind(&Server::HandleNewInstanceRequest, this, std::placeholders::_1)); communicator->RegisterMsgCallBack("queryInstance", std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1)); communicator->RegisterMsgCallBack("syncAfterRecover", std::bind(&Server::HandleSyncAfterRecoveryRequest, this, std::placeholders::_1)); } void Server::InitExecutor() { if (executor_threshold_ == 0) { MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0."; return; } auto func_graph = func_graph_.lock(); MS_EXCEPTION_IF_NULL(func_graph); // The train engine instance is used in both push-type and pull-type kernels, // so the required_cnt of these kernels must be the same as executor_threshold_. MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_; Executor::GetInstance().Initialize(func_graph, executor_threshold_); ModelStore::GetInstance().Initialize(); // init weight memory to 0 after get model Executor::GetInstance().ResetAggregationStatus(); return; } void Server::RegisterRoundKernel() { MS_EXCEPTION_IF_NULL(iteration_); auto &rounds = iteration_->rounds(); if (rounds.empty()) { MS_LOG(EXCEPTION) << "Server has no round registered."; return; } for (auto &round : rounds) { MS_EXCEPTION_IF_NULL(round); const std::string &name = round->name(); std::shared_ptr round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name); if (round_kernel == nullptr) { MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered."; return; } // For some round kernels, the threshold count should be set. round_kernel->InitKernel(round->threshold_count()); round->BindRoundKernel(round_kernel); } return; } void Server::InitMetrics() { if (server_node_->rank_id() == kLeaderServerRank) { MS_EXCEPTION_IF_NULL(iteration_); std::shared_ptr iteration_metrics = std::make_shared(ps::PSContext::instance()->config_file_path()); if (!iteration_metrics->Initialize()) { MS_LOG(WARNING) << "Initializing metrics failed."; return; } iteration_->set_metrics(iteration_metrics); } } void Server::StartCommunicator() { if (communicators_with_worker_.empty()) { MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty."; return; } MS_EXCEPTION_IF_NULL(server_node_); MS_EXCEPTION_IF_NULL(communicator_with_server_); MS_LOG(INFO) << "Start communicator with server."; if (!communicator_with_server_->Start()) { MS_LOG(EXCEPTION) << "Starting communicator with server failed."; } DistributedMetadataStore::GetInstance().Initialize(server_node_); CollectiveOpsImpl::GetInstance().Initialize(server_node_); DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank); MS_LOG(INFO) << "This server rank is " << server_node_->rank_id(); MS_LOG(INFO) << "Start communicator with worker."; (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), [](const std::shared_ptr &communicator) { MS_ERROR_IF_NULL_WO_RET_VAL(communicator); if (!communicator->Start()) { MS_LOG(EXCEPTION) << "Starting communicator with worker failed."; } }); } void Server::Recover() { server_recovery_ = std::make_shared(); MS_EXCEPTION_IF_NULL(server_recovery_); // Try to recovery from persistent storage. if (!server_recovery_->Initialize(ps::PSContext::instance()->config_file_path())) { MS_LOG(WARNING) << "Initializing server recovery failed. Do not recover for this server."; return; } if (server_recovery_->Recover()) { // If this server recovers, need to notify cluster to reach consistency. auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); MS_LOG(INFO) << "Synchronize with leader server after recovery."; if (!server_recovery_->SyncAfterRecovery(tcp_comm, server_node_->rank_id())) { MS_LOG(EXCEPTION) << "Failed to reach consistency of the cluster after recovery."; return; } if (server_node_->rank_id() == kLeaderServerRank) { MS_EXCEPTION_IF_NULL(iteration_); iteration_->NotifyNext(false, "Move to next iteration after server 0 recovery."); } } // Set the recovery handler to Iteration. MS_EXCEPTION_IF_NULL(iteration_); iteration_->set_recovery_handler(server_recovery_); } void Server::ProcessBeforeScalingOut() { MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); iteration_->ScalingBarrier(); safemode_ = true; } void Server::ProcessBeforeScalingIn() { MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); iteration_->ScalingBarrier(); safemode_ = true; } void Server::ProcessAfterScalingOut() { std::unique_lock lock(scaling_mtx_); MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed."; } if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed."; } if (!DistributedCountService::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "DistributedCountService reinitializing failed."; } if (!iteration_->ReInitForScaling(server_node_->server_num(), server_node_->rank_id())) { MS_LOG(WARNING) << "Iteration reinitializing failed."; } if (!Executor::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "Executor reinitializing failed."; } std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); safemode_ = false; } void Server::ProcessAfterScalingIn() { std::unique_lock lock(scaling_mtx_); MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); if (server_node_->rank_id() == UINT32_MAX) { MS_LOG(WARNING) << "This server the one to be scaled in. Server need to wait SIGTERM to exit."; return; } // If the server is not the one to be scaled in, reintialize modules and recover service. if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed."; } if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed."; } if (!DistributedCountService::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "DistributedCountService reinitializing failed."; } if (!iteration_->ReInitForScaling(server_node_->server_num(), server_node_->rank_id())) { MS_LOG(WARNING) << "Iteration reinitializing failed."; } if (!Executor::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "Executor reinitializing failed."; } std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); safemode_ = false; } void Server::HandleEnableServerRequest(const std::shared_ptr &message) { MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); std::string result_message = ""; bool result = iteration_->EnableServerInstance(&result_message); nlohmann::json response; response["result"] = result; response["message"] = result_message; if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) { MS_LOG(ERROR) << "Sending response failed."; return; } } void Server::HandleDisableServerRequest(const std::shared_ptr &message) { MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); std::string result_message = ""; bool result = iteration_->DisableServerInstance(&result_message); nlohmann::json response; response["result"] = result; response["message"] = result_message; if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) { MS_LOG(ERROR) << "Sending response failed."; return; } } void Server::HandleNewInstanceRequest(const std::shared_ptr &message) { MS_ERROR_IF_NULL_WO_RET_VAL(message); MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); std::string hyper_params_str(static_cast(message->data()), message->len()); nlohmann::json new_instance_json; nlohmann::json response; try { new_instance_json = nlohmann::json::parse(hyper_params_str); } catch (const std::exception &e) { response["result"] = false; response["message"] = "The hyper-parameter data is not in json format."; if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) { MS_LOG(ERROR) << "Sending response failed."; return; } } std::string result_message = ""; bool result = iteration_->NewInstance(new_instance_json, &result_message); response["result"] = result; response["message"] = result_message; if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) { MS_LOG(ERROR) << "Sending response failed."; return; } } void Server::HandleQueryInstanceRequest(const std::shared_ptr &message) { MS_ERROR_IF_NULL_WO_RET_VAL(message); nlohmann::basic_json response; response["start_fl_job_threshold"] = ps::PSContext::instance()->start_fl_job_threshold(); response["start_fl_job_time_window"] = ps::PSContext::instance()->start_fl_job_time_window(); response["update_model_ratio"] = ps::PSContext::instance()->update_model_ratio(); response["update_model_time_window"] = ps::PSContext::instance()->update_model_time_window(); response["fl_iteration_num"] = ps::PSContext::instance()->fl_iteration_num(); response["client_epoch_num"] = ps::PSContext::instance()->client_epoch_num(); response["client_batch_size"] = ps::PSContext::instance()->client_batch_size(); response["client_learning_rate"] = ps::PSContext::instance()->client_learning_rate(); response["global_iteration_time_window"] = ps::PSContext::instance()->global_iteration_time_window(); auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) { MS_LOG(ERROR) << "Sending response failed."; return; } } void Server::HandleSyncAfterRecoveryRequest(const std::shared_ptr &message) { MS_ERROR_IF_NULL_WO_RET_VAL(message); MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); MS_LOG(INFO) << "Receive SyncAfterRecover request from other server."; std::string response = "success"; if (!tcp_comm->SendResponse(response.c_str(), response.size(), message)) { MS_LOG(ERROR) << "Sending response of SyncAfterRecoverRequest failed."; return; } if (!safemode_.load()) { MS_LOG(INFO) << "Need to synchronize for other server's recovery"; SyncAfterRecover sync_after_recovery_req; (void)sync_after_recovery_req.ParseFromArray(message->data(), SizeToInt(message->len())); if (!iteration_->SyncAfterRecovery(sync_after_recovery_req.current_iter_num())) { MS_LOG(ERROR) << "Sync after recovery failed."; return; } } } } // namespace server } // namespace fl } // namespace mindspore