| @@ -293,6 +293,7 @@ inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string & | |||
| // Definitions for Federated Learning. | |||
| /* Reason string handed to Iteration::NotifyNext when a request to another server cannot be sent. */ constexpr auto kNetworkError = "Cluster networking failed."; | |||
| /* Reason string handed to Iteration::NotifyNext when TriggerCounterEvent returns false. NOTE(review): the capital 'K' prefix is inconsistent with kNetworkError - confirm whether it should be kTriggerCounterEventError. */ constexpr auto KTriggerCounterEventError = "Cluster trigger counter event failed."; | |||
| // The result code used for round kernels. | |||
| enum class ResultCode { | |||
| @@ -86,7 +86,7 @@ bool DistributedCountService::ReInitCounter(const std::string &name, size_t glob | |||
| return true; | |||
| } | |||
| bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) { | |||
| bool DistributedCountService::Count(const std::string &name, const std::string &id) { | |||
| MS_LOG(DEBUG) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; | |||
| if (local_rank_ == counting_server_rank_) { | |||
| if (global_threshold_count_.count(name) == 0) { | |||
| @@ -107,9 +107,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| MS_LOG(INFO) << "Global current count for " << name << " is: " << global_current_count_[name].size() << "/" | |||
| << global_threshold_count_[name]; | |||
| } | |||
| if (!TriggerCounterEvent(name, reason)) { | |||
| if (!TriggerCounterEvent(name)) { | |||
| MS_LOG(WARNING) << "Leader server trigger count event failed."; | |||
| Iteration::GetInstance().NotifyNext(false, *reason); | |||
| Iteration::GetInstance().NotifyNext(false, KTriggerCounterEventError); | |||
| return false; | |||
| } | |||
| } else { | |||
| @@ -121,10 +121,8 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr; | |||
| if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount, | |||
| &report_cnt_rsp_msg)) { | |||
| MS_LOG(WARNING) << "Sending reporting count message to leader server failed for " << name; | |||
| if (reason != nullptr) { | |||
| *reason = kNetworkError; | |||
| } | |||
| MS_LOG(WARNING) << "Sending reporting count " + name + " message to leader server failed for fl id " << id; | |||
| Iteration::GetInstance().NotifyNext(false, kNetworkError); | |||
| return false; | |||
| } | |||
| @@ -133,10 +131,6 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); | |||
| if (!count_rsp.result()) { | |||
| MS_LOG(WARNING) << "Reporting count failed:" << count_rsp.reason(); | |||
| // If the error is caused by the network issue, return the reason. | |||
| if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) { | |||
| *reason = kNetworkError; | |||
| } | |||
| return false; | |||
| } | |||
| } | |||
| @@ -263,9 +257,8 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core: | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| std::string reason = "success"; | |||
| if (!TriggerCounterEvent(name, &reason)) { | |||
| Iteration::GetInstance().NotifyNext(false, reason); | |||
| if (!TriggerCounterEvent(name)) { | |||
| Iteration::GetInstance().NotifyNext(false, KTriggerCounterEventError); | |||
| } | |||
| } | |||
| @@ -322,7 +315,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core: | |||
| return; | |||
| } | |||
| bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) { | |||
| bool DistributedCountService::TriggerCounterEvent(const std::string &name) { | |||
| if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) { | |||
| MS_LOG(WARNING) << "The counter of " << name << " is not registered."; | |||
| return false; | |||
| @@ -332,19 +325,19 @@ bool DistributedCountService::TriggerCounterEvent(const std::string &name, std:: | |||
| << ", threshold count is " << global_threshold_count_[name]; | |||
| // The threshold count may be 1, in which case both the first and the last count events must be activated. | |||
| if (global_current_count_[name].size() == 1) { | |||
| if (!TriggerFirstCountEvent(name, reason)) { | |||
| if (!TriggerFirstCountEvent(name)) { | |||
| return false; | |||
| } | |||
| } | |||
| if (global_current_count_[name].size() == global_threshold_count_[name]) { | |||
| if (!TriggerLastCountEvent(name, reason)) { | |||
| if (!TriggerLastCountEvent(name)) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, std::string *reason) { | |||
| bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) { | |||
| MS_LOG(DEBUG) << "Activating first count event for " << name; | |||
| CounterEvent first_count_event; | |||
| first_count_event.set_type(CounterEventType::FIRST_CNT); | |||
| @@ -354,10 +347,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st | |||
| for (uint32_t i = 1; i < server_num_; i++) { | |||
| MS_LOG(DEBUG) << "Start sending first count event message to server " << i; | |||
| if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { | |||
| MS_LOG(WARNING) << "Activating first count event to server " << i << " failed."; | |||
| if (reason != nullptr) { | |||
| *reason = kNetworkError; | |||
| } | |||
| MS_LOG(WARNING) << "Send activating first count event to server " << i << " failed."; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -374,7 +364,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st | |||
| return true; | |||
| } | |||
| bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std::string *reason) { | |||
| bool DistributedCountService::TriggerLastCountEvent(const std::string &name) { | |||
| MS_LOG(DEBUG) << "Activating last count event for " << name; | |||
| CounterEvent last_count_event; | |||
| last_count_event.set_type(CounterEventType::LAST_CNT); | |||
| @@ -384,10 +374,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std | |||
| for (uint32_t i = 1; i < server_num_; i++) { | |||
| MS_LOG(DEBUG) << "Start sending last count event message to server " << i; | |||
| if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { | |||
| MS_LOG(WARNING) << "Activating last count event to server " << i << " failed."; | |||
| if (reason != nullptr) { | |||
| *reason = kNetworkError; | |||
| } | |||
| MS_LOG(WARNING) << "Send activating last count event to server " << i << " failed."; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -67,9 +67,8 @@ class DistributedCountService { | |||
| // Reinitialize counter due to the change of threshold count. | |||
| bool ReInitCounter(const std::string &name, size_t global_threshold_count); | |||
| // Report a count to the counting server. Parameter 'id' is in case of repeated counting. Parameter 'reason' is the | |||
| // reason why counting failed. | |||
| bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr); | |||
| // Report a count to the counting server. Parameter 'id' is in case of repeated counting. | |||
| bool Count(const std::string &name, const std::string &id); | |||
| // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, | |||
| // this method returns true. | |||
| @@ -103,9 +102,9 @@ class DistributedCountService { | |||
| void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message); | |||
| // Call the callbacks when the first/last count event is triggered. | |||
| bool TriggerCounterEvent(const std::string &name, std::string *reason = nullptr); | |||
| bool TriggerFirstCountEvent(const std::string &name, std::string *reason = nullptr); | |||
| bool TriggerLastCountEvent(const std::string &name, std::string *reason = nullptr); | |||
| bool TriggerCounterEvent(const std::string &name); | |||
| bool TriggerFirstCountEvent(const std::string &name); | |||
| bool TriggerLastCountEvent(const std::string &name); | |||
| // Members for the communication between counting server and other servers. | |||
| std::shared_ptr<ps::core::ServerNode> server_node_; | |||
| @@ -86,7 +86,7 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) { | |||
| return; | |||
| } | |||
| bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) { | |||
| bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) { | |||
| if (router_ == nullptr) { | |||
| MS_LOG(WARNING) << "The consistent hash ring is not initialized yet."; | |||
| return false; | |||
| @@ -107,10 +107,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM | |||
| if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata, | |||
| &update_meta_rsp_msg)) { | |||
| MS_LOG(WARNING) << "Sending updating metadata message to server " << stored_rank << " failed."; | |||
| if (reason != nullptr) { | |||
| *reason = kNetworkError; | |||
| } | |||
| Iteration::GetInstance().NotifyNext(false, *reason); | |||
| Iteration::GetInstance().NotifyNext(false, kNetworkError); | |||
| return false; | |||
| } | |||
| @@ -55,8 +55,8 @@ class DistributedMetadataStore { | |||
| // Reset the metadata value for the name. | |||
| void ResetMetadata(const std::string &name); | |||
| // Update the metadata for the name. Parameter 'reason' is the reason why updating meta data failed. | |||
| bool UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason = nullptr); | |||
| // Update the metadata for the name. | |||
| bool UpdateMetadata(const std::string &name, const PBMetadata &meta); | |||
| // Get the metadata for the name. | |||
| PBMetadata GetMetadata(const std::string &name); | |||
| @@ -153,9 +153,8 @@ bool GetListSignKernel::Launch(const uint8_t *req_data, size_t len, | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) { | |||
| std::string reason = "Counting for get list sign request failed. Please retry later. " + count_reason; | |||
| if (!DistributedCountService::GetInstance().Count(name_, fl_id)) { | |||
| std::string reason = "Counting for get list sign request failed for fl id " + fl_id + ". Please retry later. "; | |||
| BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), | |||
| iter_num, list_signs); | |||
| MS_LOG(ERROR) << reason; | |||
| @@ -128,9 +128,8 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) { | |||
| std::string reason = "Counting for push list sign request failed. Please retry later. " + count_reason; | |||
| if (!DistributedCountService::GetInstance().Count(name_, fl_id)) { | |||
| std::string reason = "Counting for push list sign request failed for fl id " + fl_id + ". Please retry later."; | |||
| BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), | |||
| iter_num); | |||
| MS_LOG(ERROR) << reason; | |||
| @@ -85,8 +85,7 @@ ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb, | |||
| Iteration::GetInstance().set_loss(loss); | |||
| Iteration::GetInstance().set_accuracy(accuracy); | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) { | |||
| if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) { | |||
| std::string reason = "Count for push metrics request failed."; | |||
| BuildPushMetricsRsp(fbb, schema::ResponseCode_SystemError); | |||
| MS_LOG(ERROR) << reason; | |||
| @@ -110,8 +110,7 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| } | |||
| MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds."; | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) { | |||
| if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) { | |||
| std::string reason = "Count for push weight request failed."; | |||
| BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter); | |||
| MS_LOG(ERROR) << reason; | |||
| @@ -107,8 +107,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| } | |||
| PBMetadata metadata; | |||
| *metadata.mutable_device_meta() = device_meta; | |||
| std::string update_reason = ""; | |||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) { | |||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) { | |||
| std::string reason = "Updating device metadata failed for fl id " + device_meta.fl_id(); | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| @@ -116,7 +115,6 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return false; | |||
| } | |||
| // If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected. | |||
| result_code = CountForStartFLJob(fbb, start_fl_job_req); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| @@ -299,8 +297,7 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kFail); | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) { | |||
| if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) { | |||
| std::string reason = | |||
| "Counting start fl job request failed for fl id " + start_fl_job_req->fl_id()->str() + ", Please retry later."; | |||
| BuildStartFLJobRsp( | |||
| @@ -365,8 +365,7 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| fl_id.set_fl_id(update_model_fl_id); | |||
| PBMetadata comm_value; | |||
| *comm_value.mutable_fl_id() = fl_id; | |||
| std::string update_reason = ""; | |||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) { | |||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) { | |||
| std::string reason = "Updating metadata of UpdateModelClientList failed for fl id " + update_model_fl_id; | |||
| BuildUpdateModelRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| @@ -526,9 +525,8 @@ std::map<std::string, UploadData> UpdateModelKernel::DecodeFeatureMap( | |||
| } | |||
| ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) { | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason)) { | |||
| MS_LOG(ERROR) << "Counting for aggregation failed. reason: " + count_reason; | |||
| if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id)) { | |||
| MS_LOG(ERROR) << "Counting for aggregation failed for fl id " << req_fl_id; | |||
| return ResultCode::kFail; | |||
| } | |||
| return ResultCode::kSuccess; | |||
| @@ -538,10 +536,9 @@ ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilde | |||
| const schema::RequestUpdateModel *update_model_req) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail); | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) { | |||
| if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) { | |||
| std::string reason = "Counting for update model request failed for fl id " + update_model_req->fl_id()->str() + | |||
| ", Please retry later. " + count_reason; | |||
| ", Please retry later."; | |||
| BuildUpdateModelRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| @@ -142,6 +142,7 @@ constexpr char kRecoveryTotalNodeNum[] = "total_node_num"; | |||
| constexpr char kRecoveryNextWorkerRankId[] = "next_worker_rank_id"; | |||
| constexpr char kRecoveryNextServerRankId[] = "next_server_rank_id"; | |||
| constexpr char kRecoveryRegisteredNodesInfos[] = "node_ids"; | |||
| constexpr char kRecoveryClusterState[] = "cluster_state"; | |||
| constexpr char kServerCertPath[] = "server_cert_path"; | |||
| constexpr char kServerPassword[] = "server_password"; | |||
| @@ -46,7 +46,8 @@ struct ClusterConfig { | |||
| scheduler_timeout(30), | |||
| initial_total_node_num(0), | |||
| initial_next_worker_rank_id(0), | |||
| initial_next_server_rank_id(0) {} | |||
| initial_next_server_rank_id(0), | |||
| initial_cluster_state(ClusterState::CLUSTER_STARTING) {} | |||
| // Configured through the environment variable MS_WORKER_NUM. | |||
| uint32_t initial_worker_num; | |||
| // Configured through the environment variable MS_SERVER_NUM. | |||
| @@ -72,6 +73,7 @@ struct ClusterConfig { | |||
| uint32_t initial_total_node_num; | |||
| uint32_t initial_next_worker_rank_id; | |||
| uint32_t initial_next_server_rank_id; | |||
| ClusterState initial_cluster_state; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -58,6 +58,7 @@ | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <algorithm> | |||
| #include "proto/comm.pb.h" | |||
| @@ -99,6 +100,19 @@ const std::vector<std::string> kClusterState = { | |||
| "CLUSTER_SCALE_OUT_ROLLBACK", // When the cluster is scale out rollback. | |||
| }; | |||
| /* Lookup table from a cluster state name to its ClusterState enumerator; every key spells the enumerator it maps to. */ const std::map<std::string, ClusterState> kClusterStateMap = { | |||
| {"CLUSTER_STARTING", ClusterState::CLUSTER_STARTING}, | |||
| {"CLUSTER_READY", ClusterState::CLUSTER_READY}, | |||
| {"CLUSTER_EXIT", ClusterState::CLUSTER_EXIT}, | |||
| {"NODE_TIMEOUT", ClusterState::NODE_TIMEOUT}, | |||
| {"CLUSTER_SCALE_OUT", ClusterState::CLUSTER_SCALE_OUT}, | |||
| {"CLUSTER_SCALE_IN", ClusterState::CLUSTER_SCALE_IN}, | |||
| {"CLUSTER_NEW_INSTANCE", ClusterState::CLUSTER_NEW_INSTANCE}, | |||
| {"CLUSTER_ENABLE_FLS", ClusterState::CLUSTER_ENABLE_FLS}, | |||
| {"CLUSTER_DISABLE_FLS", ClusterState::CLUSTER_DISABLE_FLS}, | |||
| {"CLUSTER_SCHEDULER_RECOVERY", ClusterState::CLUSTER_SCHEDULER_RECOVERY}, | |||
| {"CLUSTER_SCALE_OUT_ROLLBACK", ClusterState::CLUSTER_SCALE_OUT_ROLLBACK}}; | |||
| class CommUtil { | |||
| public: | |||
| static bool CheckIpWithRegex(const std::string &ip); | |||
| @@ -117,7 +117,6 @@ void FileConfiguration::PersistNodes(const core::ClusterConfig &clusterConfig) c | |||
| res["node_id"] = node_info.node_id_; | |||
| res["rank_id"] = std::to_string(node_info.rank_id_); | |||
| res["role"] = CommUtil::NodeRoleToString(node_info.node_role_); | |||
| res["alive"] = CommUtil::BoolToString(node_info.is_alive); | |||
| persist_js["node_ids"].push_back(res); | |||
| } | |||
| @@ -138,6 +137,7 @@ void FileConfiguration::PersistFile(const core::ClusterConfig &clusterConfig) co | |||
| persist_js[kRecoveryServerNum] = clusterConfig.initial_server_num; | |||
| persist_js[kRecoverySchedulerIp] = clusterConfig.scheduler_host; | |||
| persist_js[kRecoverySchedulerPort] = clusterConfig.scheduler_port; | |||
| persist_js[kRecoveryClusterState] = CommUtil::ClusterStateToString(clusterConfig.initial_cluster_state); | |||
| std::ofstream output_file(file_path_); | |||
| output_file << persist_js.dump(); | |||
| @@ -225,6 +225,8 @@ std::string FollowerScaler::GetNodeScaleStateStr() { | |||
| return "kWaiting"; | |||
| case NodeScaleState::kScaling: | |||
| return "kScaling"; | |||
| case NodeScaleState::kRollback: | |||
| return "kRollback"; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "scale_state is not supported."; | |||
| } | |||
| @@ -35,7 +35,7 @@ void InstanceManager::NewInstanceAsync(const std::shared_ptr<TcpClient> &client, | |||
| MS_LOG(WARNING) << "Send new instance timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The scheduler is sending new instance to workers and servers!"; | |||
| MS_LOG(INFO) << "The scheduler is sending new instance to " << node_info.node_id_; | |||
| } | |||
| void InstanceManager::QueryInstanceAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &, | |||
| @@ -55,7 +55,7 @@ void InstanceManager::QueryInstanceAsync(const std::shared_ptr<TcpClient> &clien | |||
| MS_LOG(WARNING) << "Send query instance timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!"; | |||
| MS_LOG(INFO) << "The scheduler is sending query instance to " << node_info.node_id_; | |||
| } | |||
| void InstanceManager::EnableFLSAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &, | |||
| @@ -75,7 +75,7 @@ void InstanceManager::EnableFLSAsync(const std::shared_ptr<TcpClient> &client, c | |||
| MS_LOG(WARNING) << "Send query instance timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!"; | |||
| MS_LOG(INFO) << "The scheduler is sending enable FLS to " << node_info.node_id_; | |||
| } | |||
| void InstanceManager::DisableFLSAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &, | |||
| @@ -95,7 +95,7 @@ void InstanceManager::DisableFLSAsync(const std::shared_ptr<TcpClient> &client, | |||
| MS_LOG(WARNING) << "Send query instance timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!"; | |||
| MS_LOG(INFO) << "The scheduler is sending disable FLS to " << node_info.node_id_; | |||
| } | |||
| void InstanceManager::QueryNodeScaleState(const std::shared_ptr<TcpClient> &client, const NodeManager &, | |||
| @@ -115,7 +115,7 @@ void InstanceManager::QueryNodeScaleState(const std::shared_ptr<TcpClient> &clie | |||
| MS_LOG(WARNING) << "Send query node scale state timeout!"; | |||
| } | |||
| MS_LOG(INFO) << "The scheduler is sending query node scale state to workers and servers!"; | |||
| MS_LOG(INFO) << "The scheduler is sending query node scale state to " << node_info.node_id_; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -54,7 +54,6 @@ class BACKEND_EXPORT Node { | |||
| is_already_stopped_(true), | |||
| is_already_finished_(false), | |||
| next_request_id_(0), | |||
| current_node_state_(NodeState::NODE_STARTING), | |||
| current_cluster_state_(ClusterState::CLUSTER_STARTING) {} | |||
| virtual ~Node() = default; | |||
| @@ -113,8 +112,6 @@ class BACKEND_EXPORT Node { | |||
| std::mutex message_tracker_mutex_; | |||
| std::condition_variable message_tracker_cond_; | |||
| // Worker and server receive the node state and cluster state from the scheduler. | |||
| NodeState current_node_state_; | |||
| ClusterState current_cluster_state_; | |||
| // Configuration file. The format is as follows | |||
| @@ -70,8 +70,7 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage ®ister_message | |||
| registered_nodes_info_[node_id] = recovery_node_infos[node_id]; | |||
| MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!" | |||
| << ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_ | |||
| << ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive | |||
| << ", fl iteration num: " << new_fl_iteration_num | |||
| << ", rank id: " << rank_id << ", fl iteration num: " << new_fl_iteration_num | |||
| << ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_); | |||
| return rank_id; | |||
| } | |||
| @@ -235,12 +234,15 @@ void NodeManager::UpdateCluster(bool is_cluster_ready) { | |||
| timeout_nodes_info_[it->first] = registered_nodes_info_[it->first]; | |||
| registered_nodes_info_[it->first].is_alive = false; | |||
| } | |||
| } else { | |||
| if (registered_nodes_info_.count(it->first) && !registered_nodes_info_[it->first].is_alive) { | |||
| MS_LOG(WARNING) << registered_nodes_info_[it->first].node_id_ << " is alive."; | |||
| registered_nodes_info_[it->first].is_alive = true; | |||
| } | |||
| } | |||
| } | |||
| if (!timeout_nodes_info_.empty()) { | |||
| UpdateClusterState(ClusterState::NODE_TIMEOUT); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) { | |||
| @@ -249,25 +251,14 @@ void NodeManager::UpdateCluster(bool is_cluster_ready) { | |||
| finish_nodes_id_.insert(iter->first); | |||
| } | |||
| } | |||
| if (onPersist_) { | |||
| onPersist_(); | |||
| if (cluster_state_ != ClusterState::CLUSTER_DISABLE_FLS) { | |||
| UpdateClusterState(ClusterState::NODE_TIMEOUT); | |||
| } | |||
| } else if (SizeToUint(heartbeats_.size()) == total_node_num_) { | |||
| if (cluster_state_ == ClusterState::NODE_TIMEOUT) { | |||
| for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) { | |||
| if (registered_nodes_info_.count(it->first) && !it->second.is_alive) { | |||
| MS_LOG(WARNING) << it->second.node_id_ << " is alive."; | |||
| it->second.is_alive = true; | |||
| } | |||
| } | |||
| if (onPersist_) { | |||
| onPersist_(); | |||
| } | |||
| if (is_cluster_ready) { | |||
| UpdateClusterState(ClusterState::CLUSTER_READY); | |||
| } else { | |||
| UpdateClusterState(ClusterState::CLUSTER_STARTING); | |||
| } | |||
| } else if (SizeToUint(heartbeats_.size()) == total_node_num_ && cluster_state_ == ClusterState::NODE_TIMEOUT) { | |||
| if (is_cluster_ready) { | |||
| UpdateClusterState(ClusterState::CLUSTER_READY); | |||
| } else { | |||
| UpdateClusterState(ClusterState::CLUSTER_STARTING); | |||
| } | |||
| } | |||
| @@ -324,11 +315,6 @@ void NodeManager::UpdateNodesInfo() { | |||
| nodes_info_ = registered_nodes_info_; | |||
| } | |||
| void NodeManager::UpdateNodeState(const NodeState &state) { | |||
| // Hold node_mutex_ for the whole assignment so the stored node state is replaced atomically. | |||
| const std::lock_guard<std::mutex> guard(node_mutex_); | |||
| node_state_ = state; | |||
| } | |||
| void NodeManager::UpdateClusterState(const ClusterState &state) { | |||
| std::lock_guard<std::mutex> lk(cluster_mutex_); | |||
| std::string state_str = CommUtil::ClusterStateToString(state); | |||
| @@ -340,11 +326,6 @@ void NodeManager::UpdateClusterState(const ClusterState &state) { | |||
| cluster_state_ = state; | |||
| } | |||
| NodeState NodeManager::GetNodeState() { | |||
| // Read under node_mutex_ and hand back a copy of the stored node state. | |||
| const std::lock_guard<std::mutex> guard(node_mutex_); | |||
| return node_state_; | |||
| } | |||
| ClusterState NodeManager::GetClusterState() { | |||
| std::lock_guard<std::mutex> lk(cluster_mutex_); | |||
| return cluster_state_; | |||
| @@ -49,7 +49,6 @@ class NodeManager { | |||
| next_worker_rank_id_(0), | |||
| next_server_rank_id_(0), | |||
| meta_data_(nullptr), | |||
| node_state_(NodeState::NODE_STARTING), | |||
| cluster_state_(ClusterState::CLUSTER_STARTING) {} | |||
| virtual ~NodeManager() = default; | |||
| using OnPersist = std::function<void()>; | |||
| @@ -104,9 +103,7 @@ class NodeManager { | |||
| uint32_t next_worker_rank_id() const; | |||
| uint32_t next_server_rank_id() const; | |||
| void UpdateNodeState(const NodeState &state); | |||
| void UpdateClusterState(const ClusterState &state); | |||
| NodeState GetNodeState(); | |||
| ClusterState GetClusterState(); | |||
| // When the scheduler receives the scale out or scale in message, the metadata needs to be reset, because all nodes | |||
| @@ -179,7 +176,6 @@ class NodeManager { | |||
| // Cluster metadata information can be dynamically changed | |||
| std::unique_ptr<ClusterMetadata> meta_data_; | |||
| NodeState node_state_; | |||
| ClusterState cluster_state_; | |||
| std::deque<uint32_t> recovery_worker_rank_id_; | |||
| @@ -57,7 +57,6 @@ bool SchedulerNode::Start(const uint32_t &timeout) { | |||
| MS_LOG(ERROR) << "Start Scheduler node timeout!"; | |||
| return false; | |||
| } | |||
| node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); | |||
| StartUpdatePersistentCommandTimer(); | |||
| MS_LOG(INFO) << "[Scheduler start]: 4. Successfully start scheduler, there are " << node_manager_.worker_num() | |||
| @@ -85,7 +84,26 @@ void SchedulerNode::RunRecovery() { | |||
| MS_LOG(WARNING) << "There is no registered nodes in scheduler!"; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "The scheduler start run recovery!"; | |||
| MS_LOG(INFO) << "The scheduler start run recovery!" | |||
| << " The worker num:" << clusterConfig.initial_worker_num | |||
| << ", the server num:" << clusterConfig.initial_server_num | |||
| << ", the scheduler ip:" << clusterConfig.scheduler_host | |||
| << ", the scheduler port:" << clusterConfig.scheduler_port | |||
| << ", the initial total node num:" << clusterConfig.initial_total_node_num | |||
| << ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id | |||
| << ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id | |||
| << ", the initial cluster state:" << kClusterState.at(clusterConfig.initial_cluster_state); | |||
| // Log every registered node carried in clusterConfig so the recovered topology can be checked in the log. | |||
| if (!clusterConfig.initial_registered_nodes_infos.empty()) { | |||
| // Bind by const reference: iterating by value copied a whole key/node-info pair on every pass. | |||
| for (const auto &kvs : clusterConfig.initial_registered_nodes_infos) { | |||
| const auto &node_info = kvs.second; | |||
| MS_LOG(INFO) << "The ip:" << node_info.ip_ << ", the port:" << node_info.port_ | |||
| << ", the node_id:" << node_info.node_id_ | |||
| << ", the node_role:" << CommUtil::NodeRoleToString(node_info.node_role_) | |||
| << ", the rank_id_:" << node_info.rank_id_ | |||
| << ", the is_alive:" << CommUtil::BoolToString(node_info.is_alive); | |||
| } | |||
| } | |||
| uint32_t worker_num = clusterConfig.initial_worker_num; | |||
| uint32_t server_num = clusterConfig.initial_server_num; | |||
| @@ -94,6 +112,11 @@ void SchedulerNode::RunRecovery() { | |||
| node_manager_.set_next_worker_rank_id(clusterConfig.initial_next_worker_rank_id); | |||
| node_manager_.set_next_server_rank_id(clusterConfig.initial_next_server_rank_id); | |||
| node_manager_.set_total_node_num(clusterConfig.initial_total_node_num); | |||
| if (clusterConfig.initial_cluster_state == ClusterState::CLUSTER_DISABLE_FLS) { | |||
| MS_LOG(WARNING) << "Scheduler recover and update cluster state from recovery file, cluster state is " | |||
| << CommUtil::ClusterStateToString(clusterConfig.initial_cluster_state); | |||
| node_manager_.UpdateClusterState(clusterConfig.initial_cluster_state); | |||
| } | |||
| for (const auto &kvs : initial_node_infos) { | |||
| auto &node_id = kvs.first; | |||
| @@ -352,44 +375,56 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server, | |||
| "will exit later."; | |||
| return; | |||
| } | |||
| if (!BuildingNetwork()) { | |||
| MS_LOG(ERROR) << "Building network failed! Cluster will exit later."; | |||
| } | |||
| } | |||
| } | |||
| if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) { | |||
| auto nodes = node_manager_.nodes_info(); | |||
| for (const auto &id : scale_in_node_ids_) { | |||
| MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id; | |||
| if (nodes.count(id)) { | |||
| auto scale_in_client = GetOrCreateClient(nodes[id]); | |||
| SendMetadata(scale_in_client, nodes[id].rank_id_); | |||
| node_manager_.UpdateHeartbeat(id); | |||
| } | |||
| if (connected_nodes_.count(id)) { | |||
| MS_LOG(INFO) << "remove scale in node id: " << id << " connection."; | |||
| connected_nodes_.erase(id); | |||
| } | |||
| bool SchedulerNode::BuildingNetwork() { | |||
| if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) { | |||
| auto nodes = node_manager_.nodes_info(); | |||
| for (const auto &id : scale_in_node_ids_) { | |||
| MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id; | |||
| if (nodes.count(id)) { | |||
| auto scale_in_client = GetOrCreateClient(nodes[id]); | |||
| SendMetadata(scale_in_client, nodes[id].rank_id_); | |||
| node_manager_.UpdateHeartbeat(id); | |||
| } | |||
| if (connected_nodes_.count(id)) { | |||
| MS_LOG(INFO) << "remove scale in node id: " << id << " connection."; | |||
| connected_nodes_.erase(id); | |||
| } | |||
| } | |||
| node_manager_.UpdateNodesInfo(); | |||
| auto node_infos = node_manager_.nodes_info(); | |||
| bool res = SendPrepareBuildingNetwork(node_infos); | |||
| if (!res) { | |||
| MS_LOG(ERROR) << "Prepare for building network failed! Cluster will exit later."; | |||
| return; | |||
| } | |||
| is_ready_ = true; | |||
| MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and " | |||
| << node_manager_.server_num() | |||
| << " servers registered to scheduer, so the scheduler send meta data to worker/server."; | |||
| } | |||
| node_manager_.UpdateNodesInfo(); | |||
| auto node_infos = node_manager_.nodes_info(); | |||
| bool res = SendPrepareBuildingNetwork(node_infos); | |||
| if (!res) { | |||
| MS_LOG(ERROR) << "Prepare for building network failed!"; | |||
| return false; | |||
| } | |||
| is_ready_ = true; | |||
| MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and " | |||
| << node_manager_.server_num() | |||
| << " servers registered to scheduer, so the scheduler send meta data to worker/server."; | |||
| for (const auto &kvs : node_infos) { | |||
| auto client = GetOrCreateClient(kvs.second); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| SendMetadata(client, kvs.second.rank_id_); | |||
| node_manager_.UpdateHeartbeat(kvs.first); | |||
| } | |||
| for (const auto &kvs : node_infos) { | |||
| auto client = GetOrCreateClient(kvs.second); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| SendMetadata(client, kvs.second.rank_id_); | |||
| node_manager_.UpdateHeartbeat(kvs.first); | |||
| } | |||
| if (node_manager_.GetClusterState() == ClusterState::CLUSTER_DISABLE_FLS) { | |||
| MS_LOG(WARNING) | |||
| << "Cluster state is CLUSTER_DISABLE_FLS, do not need to change to CLUSTER_READY when building network."; | |||
| } else { | |||
| node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); | |||
| PersistMetaData(); | |||
| wait_start_cond_.notify_all(); | |||
| } | |||
| PersistMetaData(); | |||
| wait_start_cond_.notify_all(); | |||
| return true; | |||
| } | |||
| void SchedulerNode::ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| @@ -897,7 +932,7 @@ bool SchedulerNode::QueryNodeScaleState(const std::shared_ptr<HttpMessageHandler | |||
| auto client = GetOrCreateClient(kvs.second); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| MS_EXCEPTION_IF_NULL(instance_manager_); | |||
| instance_manager_->QueryNodeScaleState(client, node_manager_, request_id, node_info_); | |||
| instance_manager_->QueryNodeScaleState(client, node_manager_, request_id, kvs.second); | |||
| } | |||
| } | |||
| @@ -1179,7 +1214,7 @@ void SchedulerNode::ProcessNewInstance(const std::shared_ptr<HttpMessageHandler> | |||
| auto client = GetOrCreateClient(kvs.second); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| MS_EXCEPTION_IF_NULL(instance_manager_); | |||
| instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, node_info_); | |||
| instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, kvs.second); | |||
| } | |||
| } | |||
| bool res = Wait(request_id); | |||
| @@ -1246,7 +1281,7 @@ void SchedulerNode::ProcessQueryInstance(const std::shared_ptr<HttpMessageHandle | |||
| auto client = GetOrCreateClient(kvs.second); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| MS_EXCEPTION_IF_NULL(instance_manager_); | |||
| instance_manager_->QueryInstanceAsync(client, node_manager_, request_id, node_info_); | |||
| instance_manager_->QueryInstanceAsync(client, node_manager_, request_id, kvs.second); | |||
| } | |||
| } | |||
| bool res = Wait(request_id); | |||
| @@ -1308,7 +1343,7 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> & | |||
| auto client = GetOrCreateClient(kvs.second); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| MS_EXCEPTION_IF_NULL(instance_manager_); | |||
| instance_manager_->EnableFLSAsync(client, node_manager_, request_id, node_info_); | |||
| instance_manager_->EnableFLSAsync(client, node_manager_, request_id, kvs.second); | |||
| } | |||
| } | |||
| bool res = Wait(request_id); | |||
| @@ -1334,6 +1369,7 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> & | |||
| js["code"] = kSuccessCode; | |||
| js["result"] = true; | |||
| node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); | |||
| PersistMetaData(); | |||
| } else { | |||
| js["message"] = "start enabling FL-Server failed."; | |||
| js["code"] = kErrorCode; | |||
| @@ -1379,7 +1415,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler> | |||
| auto client = GetOrCreateClient(kvs.second); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| MS_EXCEPTION_IF_NULL(instance_manager_); | |||
| instance_manager_->DisableFLSAsync(client, node_manager_, request_id, node_info_); | |||
| instance_manager_->DisableFLSAsync(client, node_manager_, request_id, kvs.second); | |||
| } | |||
| } | |||
| bool res = Wait(request_id); | |||
| @@ -1404,6 +1440,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler> | |||
| js["code"] = kSuccessCode; | |||
| js["result"] = true; | |||
| node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS); | |||
| PersistMetaData(); | |||
| } else { | |||
| js["message"] = "start disabling FL-Server failed."; | |||
| js["code"] = kErrorCode; | |||
| @@ -1565,6 +1602,7 @@ void SchedulerNode::PersistMetaData() { | |||
| return; | |||
| } | |||
| if (!is_ready_) { | |||
| MS_LOG(WARNING) << "Cluster is not building network successful, do not persist meta data"; | |||
| return; | |||
| } | |||
| if (config_->Exists(kKeyRecovery)) { | |||
| @@ -1576,6 +1614,7 @@ void SchedulerNode::PersistMetaData() { | |||
| clusterConfig.initial_next_server_rank_id = node_manager_.next_server_rank_id(); | |||
| clusterConfig.initial_registered_nodes_infos.clear(); | |||
| clusterConfig.initial_registered_nodes_infos = node_manager_.registered_nodes_info(); | |||
| clusterConfig.initial_cluster_state = node_manager_.GetClusterState(); | |||
| scheduler_recovery_->Persist(clusterConfig); | |||
| scheduler_recovery_->PersistNodesInfo(clusterConfig); | |||
| @@ -217,6 +217,8 @@ class BACKEND_EXPORT SchedulerNode : public Node { | |||
| void GeneralResponse(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, bool is_success, const std::string &error); | |||
| bool BuildingNetwork(); | |||
| std::shared_ptr<TcpServer> server_; | |||
| std::unique_ptr<std::thread> scheduler_thread_; | |||
| std::unique_ptr<std::thread> update_state_thread_; | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "ps/core/scheduler_recovery.h" | |||
| #include "ps/core/comm_util.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -63,11 +64,6 @@ bool SchedulerRecovery::Recover() { | |||
| MS_LOG(EXCEPTION) << kRecoverySchedulerPort << " is not contained in " << recovery_storage_->file_path(); | |||
| } | |||
| MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num | |||
| << ", the server num:" << clusterConfig.initial_server_num | |||
| << ", the scheduler ip:" << clusterConfig.scheduler_host | |||
| << ", the scheduler port:" << clusterConfig.scheduler_port; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(scheduler_recovery_storage_, false); | |||
| // 5. recover total node num | |||
| if (scheduler_recovery_storage_->Exists(kRecoveryTotalNodeNum)) { | |||
| @@ -110,7 +106,6 @@ bool SchedulerRecovery::Recover() { | |||
| node_info.port_ = static_cast<uint16_t>(std::strtol(port.c_str(), nullptr, kBase)); | |||
| node_info.node_id_ = elem.at("node_id"); | |||
| node_info.rank_id_ = UlongToUint(std::strtoul(rank_id.c_str(), nullptr, kBase)); | |||
| node_info.is_alive = CommUtil::StringToBool(elem.at("alive")); | |||
| node_info.node_role_ = CommUtil::StringToNodeRole(elem.at("role")); | |||
| nodes_infos[node_info.node_id_] = node_info; | |||
| @@ -127,18 +122,11 @@ bool SchedulerRecovery::Recover() { | |||
| MS_LOG(EXCEPTION) << kRecoveryRegisteredNodesInfos << " is not contained in " << recovery_storage_->file_path(); | |||
| } | |||
| MS_LOG(INFO) << ", the initial total node num:" << clusterConfig.initial_total_node_num | |||
| << ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id | |||
| << ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id; | |||
| if (!clusterConfig.initial_registered_nodes_infos.empty()) { | |||
| for (const auto kvs : clusterConfig.initial_registered_nodes_infos) { | |||
| MS_LOG(INFO) << "The ip:" << kvs.second.ip_ << ", the port:" << kvs.second.port_ | |||
| << ", the node_id:" << kvs.second.node_id_ | |||
| << ", the node_role:" << CommUtil::NodeRoleToString(kvs.second.node_role_) | |||
| << ", the rank_id_:" << kvs.second.rank_id_ | |||
| << ", the is_alive:" << CommUtil::BoolToString(kvs.second.is_alive); | |||
| } | |||
| // 9. recover cluster state | |||
| if (recovery_storage_->Exists(kRecoveryClusterState)) { | |||
| clusterConfig.initial_cluster_state = kClusterStateMap.at(recovery_storage_->GetString(kRecoveryClusterState, "")); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << kRecoveryClusterState << " is not contained in " << recovery_storage_->file_path(); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -209,8 +209,9 @@ void PSContext::set_rank_id(uint32_t rank_id) const { | |||
| } | |||
| void PSContext::set_server_mode(const std::string &server_mode) { | |||
| if (server_mode != kServerModePS && server_mode != kServerModeFL) { | |||
| MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL; | |||
| if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) { | |||
| MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL | |||
| << " or " << kServerModeHybrid; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it."; | |||