| @@ -293,6 +293,7 @@ inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string & | |||||
| // Definitions for Federated Learning. | // Definitions for Federated Learning. | ||||
| constexpr auto kNetworkError = "Cluster networking failed."; | constexpr auto kNetworkError = "Cluster networking failed."; | ||||
| constexpr auto KTriggerCounterEventError = "Cluster trigger counter event failed."; | |||||
| // The result code used for round kernels. | // The result code used for round kernels. | ||||
| enum class ResultCode { | enum class ResultCode { | ||||
| @@ -86,7 +86,7 @@ bool DistributedCountService::ReInitCounter(const std::string &name, size_t glob | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) { | |||||
| bool DistributedCountService::Count(const std::string &name, const std::string &id) { | |||||
| MS_LOG(DEBUG) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; | MS_LOG(DEBUG) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; | ||||
| if (local_rank_ == counting_server_rank_) { | if (local_rank_ == counting_server_rank_) { | ||||
| if (global_threshold_count_.count(name) == 0) { | if (global_threshold_count_.count(name) == 0) { | ||||
| @@ -107,9 +107,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||||
| MS_LOG(INFO) << "Global current count for " << name << " is: " << global_current_count_[name].size() << "/" | MS_LOG(INFO) << "Global current count for " << name << " is: " << global_current_count_[name].size() << "/" | ||||
| << global_threshold_count_[name]; | << global_threshold_count_[name]; | ||||
| } | } | ||||
| if (!TriggerCounterEvent(name, reason)) { | |||||
| if (!TriggerCounterEvent(name)) { | |||||
| MS_LOG(WARNING) << "Leader server trigger count event failed."; | MS_LOG(WARNING) << "Leader server trigger count event failed."; | ||||
| Iteration::GetInstance().NotifyNext(false, *reason); | |||||
| Iteration::GetInstance().NotifyNext(false, KTriggerCounterEventError); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -121,10 +121,8 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||||
| std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr; | std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr; | ||||
| if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount, | if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount, | ||||
| &report_cnt_rsp_msg)) { | &report_cnt_rsp_msg)) { | ||||
| MS_LOG(WARNING) << "Sending reporting count message to leader server failed for " << name; | |||||
| if (reason != nullptr) { | |||||
| *reason = kNetworkError; | |||||
| } | |||||
| MS_LOG(WARNING) << "Sending reporting count " + name + " message to leader server failed for fl id " << id; | |||||
| Iteration::GetInstance().NotifyNext(false, kNetworkError); | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -133,10 +131,6 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||||
| (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); | (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); | ||||
| if (!count_rsp.result()) { | if (!count_rsp.result()) { | ||||
| MS_LOG(WARNING) << "Reporting count failed:" << count_rsp.reason(); | MS_LOG(WARNING) << "Reporting count failed:" << count_rsp.reason(); | ||||
| // If the error is caused by the network issue, return the reason. | |||||
| if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) { | |||||
| *reason = kNetworkError; | |||||
| } | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -263,9 +257,8 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core: | |||||
| MS_LOG(WARNING) << "Sending response failed."; | MS_LOG(WARNING) << "Sending response failed."; | ||||
| return; | return; | ||||
| } | } | ||||
| std::string reason = "success"; | |||||
| if (!TriggerCounterEvent(name, &reason)) { | |||||
| Iteration::GetInstance().NotifyNext(false, reason); | |||||
| if (!TriggerCounterEvent(name)) { | |||||
| Iteration::GetInstance().NotifyNext(false, KTriggerCounterEventError); | |||||
| } | } | ||||
| } | } | ||||
| @@ -322,7 +315,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core: | |||||
| return; | return; | ||||
| } | } | ||||
| bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) { | |||||
| bool DistributedCountService::TriggerCounterEvent(const std::string &name) { | |||||
| if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) { | if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) { | ||||
| MS_LOG(WARNING) << "The counter of " << name << " is not registered."; | MS_LOG(WARNING) << "The counter of " << name << " is not registered."; | ||||
| return false; | return false; | ||||
| @@ -332,19 +325,19 @@ bool DistributedCountService::TriggerCounterEvent(const std::string &name, std:: | |||||
| << ", threshold count is " << global_threshold_count_[name]; | << ", threshold count is " << global_threshold_count_[name]; | ||||
| // The threshold count may be 1 so the first and last count event should be both activated. | // The threshold count may be 1 so the first and last count event should be both activated. | ||||
| if (global_current_count_[name].size() == 1) { | if (global_current_count_[name].size() == 1) { | ||||
| if (!TriggerFirstCountEvent(name, reason)) { | |||||
| if (!TriggerFirstCountEvent(name)) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| if (global_current_count_[name].size() == global_threshold_count_[name]) { | if (global_current_count_[name].size() == global_threshold_count_[name]) { | ||||
| if (!TriggerLastCountEvent(name, reason)) { | |||||
| if (!TriggerLastCountEvent(name)) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, std::string *reason) { | |||||
| bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) { | |||||
| MS_LOG(DEBUG) << "Activating first count event for " << name; | MS_LOG(DEBUG) << "Activating first count event for " << name; | ||||
| CounterEvent first_count_event; | CounterEvent first_count_event; | ||||
| first_count_event.set_type(CounterEventType::FIRST_CNT); | first_count_event.set_type(CounterEventType::FIRST_CNT); | ||||
| @@ -354,10 +347,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st | |||||
| for (uint32_t i = 1; i < server_num_; i++) { | for (uint32_t i = 1; i < server_num_; i++) { | ||||
| MS_LOG(DEBUG) << "Start sending first count event message to server " << i; | MS_LOG(DEBUG) << "Start sending first count event message to server " << i; | ||||
| if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { | if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { | ||||
| MS_LOG(WARNING) << "Activating first count event to server " << i << " failed."; | |||||
| if (reason != nullptr) { | |||||
| *reason = kNetworkError; | |||||
| } | |||||
| MS_LOG(WARNING) << "Send activating first count event to server " << i << " failed."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -374,7 +364,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std::string *reason) { | |||||
| bool DistributedCountService::TriggerLastCountEvent(const std::string &name) { | |||||
| MS_LOG(DEBUG) << "Activating last count event for " << name; | MS_LOG(DEBUG) << "Activating last count event for " << name; | ||||
| CounterEvent last_count_event; | CounterEvent last_count_event; | ||||
| last_count_event.set_type(CounterEventType::LAST_CNT); | last_count_event.set_type(CounterEventType::LAST_CNT); | ||||
| @@ -384,10 +374,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std | |||||
| for (uint32_t i = 1; i < server_num_; i++) { | for (uint32_t i = 1; i < server_num_; i++) { | ||||
| MS_LOG(DEBUG) << "Start sending last count event message to server " << i; | MS_LOG(DEBUG) << "Start sending last count event message to server " << i; | ||||
| if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { | if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { | ||||
| MS_LOG(WARNING) << "Activating last count event to server " << i << " failed."; | |||||
| if (reason != nullptr) { | |||||
| *reason = kNetworkError; | |||||
| } | |||||
| MS_LOG(WARNING) << "Send activating last count event to server " << i << " failed."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -67,9 +67,8 @@ class DistributedCountService { | |||||
| // Reinitialize counter due to the change of threshold count. | // Reinitialize counter due to the change of threshold count. | ||||
| bool ReInitCounter(const std::string &name, size_t global_threshold_count); | bool ReInitCounter(const std::string &name, size_t global_threshold_count); | ||||
| // Report a count to the counting server. Parameter 'id' is in case of repeated counting. Parameter 'reason' is the | |||||
| // reason why counting failed. | |||||
| bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr); | |||||
| // Report a count to the counting server. Parameter 'id' is in case of repeated counting. | |||||
| bool Count(const std::string &name, const std::string &id); | |||||
| // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, | // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, | ||||
| // this method returns true. | // this method returns true. | ||||
| @@ -103,9 +102,9 @@ class DistributedCountService { | |||||
| void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message); | void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message); | ||||
| // Call the callbacks when the first/last count event is triggered. | // Call the callbacks when the first/last count event is triggered. | ||||
| bool TriggerCounterEvent(const std::string &name, std::string *reason = nullptr); | |||||
| bool TriggerFirstCountEvent(const std::string &name, std::string *reason = nullptr); | |||||
| bool TriggerLastCountEvent(const std::string &name, std::string *reason = nullptr); | |||||
| bool TriggerCounterEvent(const std::string &name); | |||||
| bool TriggerFirstCountEvent(const std::string &name); | |||||
| bool TriggerLastCountEvent(const std::string &name); | |||||
| // Members for the communication between counting server and other servers. | // Members for the communication between counting server and other servers. | ||||
| std::shared_ptr<ps::core::ServerNode> server_node_; | std::shared_ptr<ps::core::ServerNode> server_node_; | ||||
| @@ -86,7 +86,7 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) { | |||||
| return; | return; | ||||
| } | } | ||||
| bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) { | |||||
| bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) { | |||||
| if (router_ == nullptr) { | if (router_ == nullptr) { | ||||
| MS_LOG(WARNING) << "The consistent hash ring is not initialized yet."; | MS_LOG(WARNING) << "The consistent hash ring is not initialized yet."; | ||||
| return false; | return false; | ||||
| @@ -107,10 +107,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM | |||||
| if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata, | if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata, | ||||
| &update_meta_rsp_msg)) { | &update_meta_rsp_msg)) { | ||||
| MS_LOG(WARNING) << "Sending updating metadata message to server " << stored_rank << " failed."; | MS_LOG(WARNING) << "Sending updating metadata message to server " << stored_rank << " failed."; | ||||
| if (reason != nullptr) { | |||||
| *reason = kNetworkError; | |||||
| } | |||||
| Iteration::GetInstance().NotifyNext(false, *reason); | |||||
| Iteration::GetInstance().NotifyNext(false, kNetworkError); | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -55,8 +55,8 @@ class DistributedMetadataStore { | |||||
| // Reset the metadata value for the name. | // Reset the metadata value for the name. | ||||
| void ResetMetadata(const std::string &name); | void ResetMetadata(const std::string &name); | ||||
| // Update the metadata for the name. Parameter 'reason' is the reason why updating meta data failed. | |||||
| bool UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason = nullptr); | |||||
| // Update the metadata for the name. | |||||
| bool UpdateMetadata(const std::string &name, const PBMetadata &meta); | |||||
| // Get the metadata for the name. | // Get the metadata for the name. | ||||
| PBMetadata GetMetadata(const std::string &name); | PBMetadata GetMetadata(const std::string &name); | ||||
| @@ -153,9 +153,8 @@ bool GetListSignKernel::Launch(const uint8_t *req_data, size_t len, | |||||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | ||||
| return true; | return true; | ||||
| } | } | ||||
| std::string count_reason = ""; | |||||
| if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) { | |||||
| std::string reason = "Counting for get list sign request failed. Please retry later. " + count_reason; | |||||
| if (!DistributedCountService::GetInstance().Count(name_, fl_id)) { | |||||
| std::string reason = "Counting for get list sign request failed for fl id " + fl_id + ". Please retry later. "; | |||||
| BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), | BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), | ||||
| iter_num, list_signs); | iter_num, list_signs); | ||||
| MS_LOG(ERROR) << reason; | MS_LOG(ERROR) << reason; | ||||
| @@ -128,9 +128,8 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign | |||||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | ||||
| return true; | return true; | ||||
| } | } | ||||
| std::string count_reason = ""; | |||||
| if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) { | |||||
| std::string reason = "Counting for push list sign request failed. Please retry later. " + count_reason; | |||||
| if (!DistributedCountService::GetInstance().Count(name_, fl_id)) { | |||||
| std::string reason = "Counting for push list sign request failed for fl id " + fl_id + ". Please retry later."; | |||||
| BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), | BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), | ||||
| iter_num); | iter_num); | ||||
| MS_LOG(ERROR) << reason; | MS_LOG(ERROR) << reason; | ||||
| @@ -85,8 +85,7 @@ ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb, | |||||
| Iteration::GetInstance().set_loss(loss); | Iteration::GetInstance().set_loss(loss); | ||||
| Iteration::GetInstance().set_accuracy(accuracy); | Iteration::GetInstance().set_accuracy(accuracy); | ||||
| std::string count_reason = ""; | |||||
| if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) { | |||||
| if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) { | |||||
| std::string reason = "Count for push metrics request failed."; | std::string reason = "Count for push metrics request failed."; | ||||
| BuildPushMetricsRsp(fbb, schema::ResponseCode_SystemError); | BuildPushMetricsRsp(fbb, schema::ResponseCode_SystemError); | ||||
| MS_LOG(ERROR) << reason; | MS_LOG(ERROR) << reason; | ||||
| @@ -110,8 +110,7 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb, | |||||
| } | } | ||||
| MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds."; | MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds."; | ||||
| std::string count_reason = ""; | |||||
| if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) { | |||||
| if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) { | |||||
| std::string reason = "Count for push weight request failed."; | std::string reason = "Count for push weight request failed."; | ||||
| BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter); | BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter); | ||||
| MS_LOG(ERROR) << reason; | MS_LOG(ERROR) << reason; | ||||
| @@ -107,8 +107,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||||
| } | } | ||||
| PBMetadata metadata; | PBMetadata metadata; | ||||
| *metadata.mutable_device_meta() = device_meta; | *metadata.mutable_device_meta() = device_meta; | ||||
| std::string update_reason = ""; | |||||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) { | |||||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) { | |||||
| std::string reason = "Updating device metadata failed for fl id " + device_meta.fl_id(); | std::string reason = "Updating device metadata failed for fl id " + device_meta.fl_id(); | ||||
| BuildStartFLJobRsp( | BuildStartFLJobRsp( | ||||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | fbb, schema::ResponseCode_OutOfTime, reason, false, | ||||
| @@ -116,7 +115,6 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | ||||
| return false; | return false; | ||||
| } | } | ||||
| // If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected. | // If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected. | ||||
| result_code = CountForStartFLJob(fbb, start_fl_job_req); | result_code = CountForStartFLJob(fbb, start_fl_job_req); | ||||
| if (result_code != ResultCode::kSuccess) { | if (result_code != ResultCode::kSuccess) { | ||||
| @@ -299,8 +297,7 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> | |||||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kFail); | MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kFail); | ||||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kFail); | MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kFail); | ||||
| std::string count_reason = ""; | |||||
| if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) { | |||||
| if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) { | |||||
| std::string reason = | std::string reason = | ||||
| "Counting start fl job request failed for fl id " + start_fl_job_req->fl_id()->str() + ", Please retry later."; | "Counting start fl job request failed for fl id " + start_fl_job_req->fl_id()->str() + ", Please retry later."; | ||||
| BuildStartFLJobRsp( | BuildStartFLJobRsp( | ||||
| @@ -365,8 +365,7 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||||
| fl_id.set_fl_id(update_model_fl_id); | fl_id.set_fl_id(update_model_fl_id); | ||||
| PBMetadata comm_value; | PBMetadata comm_value; | ||||
| *comm_value.mutable_fl_id() = fl_id; | *comm_value.mutable_fl_id() = fl_id; | ||||
| std::string update_reason = ""; | |||||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) { | |||||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) { | |||||
| std::string reason = "Updating metadata of UpdateModelClientList failed for fl id " + update_model_fl_id; | std::string reason = "Updating metadata of UpdateModelClientList failed for fl id " + update_model_fl_id; | ||||
| BuildUpdateModelRsp( | BuildUpdateModelRsp( | ||||
| fbb, schema::ResponseCode_OutOfTime, reason, | fbb, schema::ResponseCode_OutOfTime, reason, | ||||
| @@ -526,9 +525,8 @@ std::map<std::string, UploadData> UpdateModelKernel::DecodeFeatureMap( | |||||
| } | } | ||||
| ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) { | ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) { | ||||
| std::string count_reason = ""; | |||||
| if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason)) { | |||||
| MS_LOG(ERROR) << "Counting for aggregation failed. reason: " + count_reason; | |||||
| if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id)) { | |||||
| MS_LOG(ERROR) << "Counting for aggregation failed for fl id " << req_fl_id; | |||||
| return ResultCode::kFail; | return ResultCode::kFail; | ||||
| } | } | ||||
| return ResultCode::kSuccess; | return ResultCode::kSuccess; | ||||
| @@ -538,10 +536,9 @@ ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilde | |||||
| const schema::RequestUpdateModel *update_model_req) { | const schema::RequestUpdateModel *update_model_req) { | ||||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail); | MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail); | ||||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail); | MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail); | ||||
| std::string count_reason = ""; | |||||
| if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) { | |||||
| if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) { | |||||
| std::string reason = "Counting for update model request failed for fl id " + update_model_req->fl_id()->str() + | std::string reason = "Counting for update model request failed for fl id " + update_model_req->fl_id()->str() + | ||||
| ", Please retry later. " + count_reason; | |||||
| ", Please retry later."; | |||||
| BuildUpdateModelRsp( | BuildUpdateModelRsp( | ||||
| fbb, schema::ResponseCode_OutOfTime, reason, | fbb, schema::ResponseCode_OutOfTime, reason, | ||||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | ||||
| @@ -142,6 +142,7 @@ constexpr char kRecoveryTotalNodeNum[] = "total_node_num"; | |||||
| constexpr char kRecoveryNextWorkerRankId[] = "next_worker_rank_id"; | constexpr char kRecoveryNextWorkerRankId[] = "next_worker_rank_id"; | ||||
| constexpr char kRecoveryNextServerRankId[] = "next_server_rank_id"; | constexpr char kRecoveryNextServerRankId[] = "next_server_rank_id"; | ||||
| constexpr char kRecoveryRegisteredNodesInfos[] = "node_ids"; | constexpr char kRecoveryRegisteredNodesInfos[] = "node_ids"; | ||||
| constexpr char kRecoveryClusterState[] = "cluster_state"; | |||||
| constexpr char kServerCertPath[] = "server_cert_path"; | constexpr char kServerCertPath[] = "server_cert_path"; | ||||
| constexpr char kServerPassword[] = "server_password"; | constexpr char kServerPassword[] = "server_password"; | ||||
| @@ -46,7 +46,8 @@ struct ClusterConfig { | |||||
| scheduler_timeout(30), | scheduler_timeout(30), | ||||
| initial_total_node_num(0), | initial_total_node_num(0), | ||||
| initial_next_worker_rank_id(0), | initial_next_worker_rank_id(0), | ||||
| initial_next_server_rank_id(0) {} | |||||
| initial_next_server_rank_id(0), | |||||
| initial_cluster_state(ClusterState::CLUSTER_STARTING) {} | |||||
| // Configure through environment variables:MS_WORKER_NUM | // Configure through environment variables:MS_WORKER_NUM | ||||
| uint32_t initial_worker_num; | uint32_t initial_worker_num; | ||||
| // Configure through environment variables:MS_SERVER_NUM | // Configure through environment variables:MS_SERVER_NUM | ||||
| @@ -72,6 +73,7 @@ struct ClusterConfig { | |||||
| uint32_t initial_total_node_num; | uint32_t initial_total_node_num; | ||||
| uint32_t initial_next_worker_rank_id; | uint32_t initial_next_worker_rank_id; | ||||
| uint32_t initial_next_server_rank_id; | uint32_t initial_next_server_rank_id; | ||||
| ClusterState initial_cluster_state; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -58,6 +58,7 @@ | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "proto/comm.pb.h" | #include "proto/comm.pb.h" | ||||
| @@ -99,6 +100,19 @@ const std::vector<std::string> kClusterState = { | |||||
| "CLUSTER_SCALE_OUT_ROLLBACK", // When the cluster is scale out rollback. | "CLUSTER_SCALE_OUT_ROLLBACK", // When the cluster is scale out rollback. | ||||
| }; | }; | ||||
| const std::map<std::string, ClusterState> kClusterStateMap = { | |||||
| {"CLUSTER_STARTING", ClusterState::CLUSTER_STARTING}, | |||||
| {"CLUSTER_READY", ClusterState::CLUSTER_READY}, | |||||
| {"CLUSTER_EXIT", ClusterState::CLUSTER_EXIT}, | |||||
| {"NODE_TIMEOUT", ClusterState::NODE_TIMEOUT}, | |||||
| {"CLUSTER_SCALE_OUT", ClusterState::CLUSTER_SCALE_OUT}, | |||||
| {"CLUSTER_SCALE_IN", ClusterState::CLUSTER_SCALE_IN}, | |||||
| {"CLUSTER_NEW_INSTANCE", ClusterState::CLUSTER_NEW_INSTANCE}, | |||||
| {"CLUSTER_ENABLE_FLS", ClusterState::CLUSTER_ENABLE_FLS}, | |||||
| {"CLUSTER_DISABLE_FLS", ClusterState::CLUSTER_DISABLE_FLS}, | |||||
| {"CLUSTER_SCHEDULER_RECOVERY", ClusterState::CLUSTER_SCHEDULER_RECOVERY}, | |||||
| {"CLUSTER_SCALE_OUT_ROLLBACK", ClusterState::CLUSTER_SCALE_OUT_ROLLBACK}}; | |||||
| class CommUtil { | class CommUtil { | ||||
| public: | public: | ||||
| static bool CheckIpWithRegex(const std::string &ip); | static bool CheckIpWithRegex(const std::string &ip); | ||||
| @@ -117,7 +117,6 @@ void FileConfiguration::PersistNodes(const core::ClusterConfig &clusterConfig) c | |||||
| res["node_id"] = node_info.node_id_; | res["node_id"] = node_info.node_id_; | ||||
| res["rank_id"] = std::to_string(node_info.rank_id_); | res["rank_id"] = std::to_string(node_info.rank_id_); | ||||
| res["role"] = CommUtil::NodeRoleToString(node_info.node_role_); | res["role"] = CommUtil::NodeRoleToString(node_info.node_role_); | ||||
| res["alive"] = CommUtil::BoolToString(node_info.is_alive); | |||||
| persist_js["node_ids"].push_back(res); | persist_js["node_ids"].push_back(res); | ||||
| } | } | ||||
| @@ -138,6 +137,7 @@ void FileConfiguration::PersistFile(const core::ClusterConfig &clusterConfig) co | |||||
| persist_js[kRecoveryServerNum] = clusterConfig.initial_server_num; | persist_js[kRecoveryServerNum] = clusterConfig.initial_server_num; | ||||
| persist_js[kRecoverySchedulerIp] = clusterConfig.scheduler_host; | persist_js[kRecoverySchedulerIp] = clusterConfig.scheduler_host; | ||||
| persist_js[kRecoverySchedulerPort] = clusterConfig.scheduler_port; | persist_js[kRecoverySchedulerPort] = clusterConfig.scheduler_port; | ||||
| persist_js[kRecoveryClusterState] = CommUtil::ClusterStateToString(clusterConfig.initial_cluster_state); | |||||
| std::ofstream output_file(file_path_); | std::ofstream output_file(file_path_); | ||||
| output_file << persist_js.dump(); | output_file << persist_js.dump(); | ||||
| @@ -225,6 +225,8 @@ std::string FollowerScaler::GetNodeScaleStateStr() { | |||||
| return "kWaiting"; | return "kWaiting"; | ||||
| case NodeScaleState::kScaling: | case NodeScaleState::kScaling: | ||||
| return "kScaling"; | return "kScaling"; | ||||
| case NodeScaleState::kRollback: | |||||
| return "kRollback"; | |||||
| default: | default: | ||||
| MS_LOG(EXCEPTION) << "scale_state is not supported."; | MS_LOG(EXCEPTION) << "scale_state is not supported."; | ||||
| } | } | ||||
| @@ -35,7 +35,7 @@ void InstanceManager::NewInstanceAsync(const std::shared_ptr<TcpClient> &client, | |||||
| MS_LOG(WARNING) << "Send new instance timeout!"; | MS_LOG(WARNING) << "Send new instance timeout!"; | ||||
| } | } | ||||
| MS_LOG(INFO) << "The scheduler is sending new instance to workers and servers!"; | |||||
| MS_LOG(INFO) << "The scheduler is sending new instance to " << node_info.node_id_; | |||||
| } | } | ||||
| void InstanceManager::QueryInstanceAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &, | void InstanceManager::QueryInstanceAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &, | ||||
| @@ -55,7 +55,7 @@ void InstanceManager::QueryInstanceAsync(const std::shared_ptr<TcpClient> &clien | |||||
| MS_LOG(WARNING) << "Send query instance timeout!"; | MS_LOG(WARNING) << "Send query instance timeout!"; | ||||
| } | } | ||||
| MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!"; | |||||
| MS_LOG(INFO) << "The scheduler is sending query instance to " << node_info.node_id_; | |||||
| } | } | ||||
| void InstanceManager::EnableFLSAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &, | void InstanceManager::EnableFLSAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &, | ||||
| @@ -75,7 +75,7 @@ void InstanceManager::EnableFLSAsync(const std::shared_ptr<TcpClient> &client, c | |||||
| MS_LOG(WARNING) << "Send query instance timeout!"; | MS_LOG(WARNING) << "Send query instance timeout!"; | ||||
| } | } | ||||
| MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!"; | |||||
| MS_LOG(INFO) << "The scheduler is sending enable FLS to " << node_info.node_id_; | |||||
| } | } | ||||
| void InstanceManager::DisableFLSAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &, | void InstanceManager::DisableFLSAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &, | ||||
| @@ -95,7 +95,7 @@ void InstanceManager::DisableFLSAsync(const std::shared_ptr<TcpClient> &client, | |||||
| MS_LOG(WARNING) << "Send query instance timeout!"; | MS_LOG(WARNING) << "Send query instance timeout!"; | ||||
| } | } | ||||
| MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!"; | |||||
| MS_LOG(INFO) << "The scheduler is sending disable FLS to " << node_info.node_id_; | |||||
| } | } | ||||
| void InstanceManager::QueryNodeScaleState(const std::shared_ptr<TcpClient> &client, const NodeManager &, | void InstanceManager::QueryNodeScaleState(const std::shared_ptr<TcpClient> &client, const NodeManager &, | ||||
| @@ -115,7 +115,7 @@ void InstanceManager::QueryNodeScaleState(const std::shared_ptr<TcpClient> &clie | |||||
| MS_LOG(WARNING) << "Send query node scale state timeout!"; | MS_LOG(WARNING) << "Send query node scale state timeout!"; | ||||
| } | } | ||||
| MS_LOG(INFO) << "The scheduler is sending query node scale state to workers and servers!"; | |||||
| MS_LOG(INFO) << "The scheduler is sending query node scale state to " << node_info.node_id_; | |||||
| } | } | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -54,7 +54,6 @@ class BACKEND_EXPORT Node { | |||||
| is_already_stopped_(true), | is_already_stopped_(true), | ||||
| is_already_finished_(false), | is_already_finished_(false), | ||||
| next_request_id_(0), | next_request_id_(0), | ||||
| current_node_state_(NodeState::NODE_STARTING), | |||||
| current_cluster_state_(ClusterState::CLUSTER_STARTING) {} | current_cluster_state_(ClusterState::CLUSTER_STARTING) {} | ||||
| virtual ~Node() = default; | virtual ~Node() = default; | ||||
| @@ -113,8 +112,6 @@ class BACKEND_EXPORT Node { | |||||
| std::mutex message_tracker_mutex_; | std::mutex message_tracker_mutex_; | ||||
| std::condition_variable message_tracker_cond_; | std::condition_variable message_tracker_cond_; | ||||
| // Worker and server receive the node state and cluster state from the scheduler. | |||||
| NodeState current_node_state_; | |||||
| ClusterState current_cluster_state_; | ClusterState current_cluster_state_; | ||||
| // Configuration file,The format is as follows | // Configuration file,The format is as follows | ||||
| @@ -70,8 +70,7 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage ®ister_message | |||||
| registered_nodes_info_[node_id] = recovery_node_infos[node_id]; | registered_nodes_info_[node_id] = recovery_node_infos[node_id]; | ||||
| MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!" | MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!" | ||||
| << ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_ | << ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_ | ||||
| << ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive | |||||
| << ", fl iteration num: " << new_fl_iteration_num | |||||
| << ", rank id: " << rank_id << ", fl iteration num: " << new_fl_iteration_num | |||||
| << ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_); | << ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_); | ||||
| return rank_id; | return rank_id; | ||||
| } | } | ||||
| @@ -235,12 +234,15 @@ void NodeManager::UpdateCluster(bool is_cluster_ready) { | |||||
| timeout_nodes_info_[it->first] = registered_nodes_info_[it->first]; | timeout_nodes_info_[it->first] = registered_nodes_info_[it->first]; | ||||
| registered_nodes_info_[it->first].is_alive = false; | registered_nodes_info_[it->first].is_alive = false; | ||||
| } | } | ||||
| } else { | |||||
| if (registered_nodes_info_.count(it->first) && !registered_nodes_info_[it->first].is_alive) { | |||||
| MS_LOG(WARNING) << registered_nodes_info_[it->first].node_id_ << " is alive."; | |||||
| registered_nodes_info_[it->first].is_alive = true; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| if (!timeout_nodes_info_.empty()) { | if (!timeout_nodes_info_.empty()) { | ||||
| UpdateClusterState(ClusterState::NODE_TIMEOUT); | |||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) { | if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) { | ||||
| @@ -249,25 +251,14 @@ void NodeManager::UpdateCluster(bool is_cluster_ready) { | |||||
| finish_nodes_id_.insert(iter->first); | finish_nodes_id_.insert(iter->first); | ||||
| } | } | ||||
| } | } | ||||
| if (onPersist_) { | |||||
| onPersist_(); | |||||
| if (cluster_state_ != ClusterState::CLUSTER_DISABLE_FLS) { | |||||
| UpdateClusterState(ClusterState::NODE_TIMEOUT); | |||||
| } | } | ||||
| } else if (SizeToUint(heartbeats_.size()) == total_node_num_) { | |||||
| if (cluster_state_ == ClusterState::NODE_TIMEOUT) { | |||||
| for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) { | |||||
| if (registered_nodes_info_.count(it->first) && !it->second.is_alive) { | |||||
| MS_LOG(WARNING) << it->second.node_id_ << " is alive."; | |||||
| it->second.is_alive = true; | |||||
| } | |||||
| } | |||||
| if (onPersist_) { | |||||
| onPersist_(); | |||||
| } | |||||
| if (is_cluster_ready) { | |||||
| UpdateClusterState(ClusterState::CLUSTER_READY); | |||||
| } else { | |||||
| UpdateClusterState(ClusterState::CLUSTER_STARTING); | |||||
| } | |||||
| } else if (SizeToUint(heartbeats_.size()) == total_node_num_ && cluster_state_ == ClusterState::NODE_TIMEOUT) { | |||||
| if (is_cluster_ready) { | |||||
| UpdateClusterState(ClusterState::CLUSTER_READY); | |||||
| } else { | |||||
| UpdateClusterState(ClusterState::CLUSTER_STARTING); | |||||
| } | } | ||||
| } | } | ||||
| @@ -324,11 +315,6 @@ void NodeManager::UpdateNodesInfo() { | |||||
| nodes_info_ = registered_nodes_info_; | nodes_info_ = registered_nodes_info_; | ||||
| } | } | ||||
| void NodeManager::UpdateNodeState(const NodeState &state) { | |||||
| std::lock_guard<std::mutex> lk(node_mutex_); | |||||
| node_state_ = state; | |||||
| } | |||||
| void NodeManager::UpdateClusterState(const ClusterState &state) { | void NodeManager::UpdateClusterState(const ClusterState &state) { | ||||
| std::lock_guard<std::mutex> lk(cluster_mutex_); | std::lock_guard<std::mutex> lk(cluster_mutex_); | ||||
| std::string state_str = CommUtil::ClusterStateToString(state); | std::string state_str = CommUtil::ClusterStateToString(state); | ||||
| @@ -340,11 +326,6 @@ void NodeManager::UpdateClusterState(const ClusterState &state) { | |||||
| cluster_state_ = state; | cluster_state_ = state; | ||||
| } | } | ||||
| NodeState NodeManager::GetNodeState() { | |||||
| std::lock_guard<std::mutex> lk(node_mutex_); | |||||
| return node_state_; | |||||
| } | |||||
| ClusterState NodeManager::GetClusterState() { | ClusterState NodeManager::GetClusterState() { | ||||
| std::lock_guard<std::mutex> lk(cluster_mutex_); | std::lock_guard<std::mutex> lk(cluster_mutex_); | ||||
| return cluster_state_; | return cluster_state_; | ||||
| @@ -49,7 +49,6 @@ class NodeManager { | |||||
| next_worker_rank_id_(0), | next_worker_rank_id_(0), | ||||
| next_server_rank_id_(0), | next_server_rank_id_(0), | ||||
| meta_data_(nullptr), | meta_data_(nullptr), | ||||
| node_state_(NodeState::NODE_STARTING), | |||||
| cluster_state_(ClusterState::CLUSTER_STARTING) {} | cluster_state_(ClusterState::CLUSTER_STARTING) {} | ||||
| virtual ~NodeManager() = default; | virtual ~NodeManager() = default; | ||||
| using OnPersist = std::function<void()>; | using OnPersist = std::function<void()>; | ||||
| @@ -104,9 +103,7 @@ class NodeManager { | |||||
| uint32_t next_worker_rank_id() const; | uint32_t next_worker_rank_id() const; | ||||
| uint32_t next_server_rank_id() const; | uint32_t next_server_rank_id() const; | ||||
| void UpdateNodeState(const NodeState &state); | |||||
| void UpdateClusterState(const ClusterState &state); | void UpdateClusterState(const ClusterState &state); | ||||
| NodeState GetNodeState(); | |||||
| ClusterState GetClusterState(); | ClusterState GetClusterState(); | ||||
| // When the scheduler receives the scale out or scale in message, the metadata needs to be reset, because all nodes | // When the scheduler receives the scale out or scale in message, the metadata needs to be reset, because all nodes | ||||
| @@ -179,7 +176,6 @@ class NodeManager { | |||||
| // Cluster metadata information can be dynamically changed | // Cluster metadata information can be dynamically changed | ||||
| std::unique_ptr<ClusterMetadata> meta_data_; | std::unique_ptr<ClusterMetadata> meta_data_; | ||||
| NodeState node_state_; | |||||
| ClusterState cluster_state_; | ClusterState cluster_state_; | ||||
| std::deque<uint32_t> recovery_worker_rank_id_; | std::deque<uint32_t> recovery_worker_rank_id_; | ||||
| @@ -57,7 +57,6 @@ bool SchedulerNode::Start(const uint32_t &timeout) { | |||||
| MS_LOG(ERROR) << "Start Scheduler node timeout!"; | MS_LOG(ERROR) << "Start Scheduler node timeout!"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); | |||||
| StartUpdatePersistentCommandTimer(); | StartUpdatePersistentCommandTimer(); | ||||
| MS_LOG(INFO) << "[Scheduler start]: 4. Successfully start scheduler, there are " << node_manager_.worker_num() | MS_LOG(INFO) << "[Scheduler start]: 4. Successfully start scheduler, there are " << node_manager_.worker_num() | ||||
| @@ -85,7 +84,26 @@ void SchedulerNode::RunRecovery() { | |||||
| MS_LOG(WARNING) << "There is no registered nodes in scheduler!"; | MS_LOG(WARNING) << "There is no registered nodes in scheduler!"; | ||||
| return; | return; | ||||
| } | } | ||||
| MS_LOG(INFO) << "The scheduler start run recovery!"; | |||||
| MS_LOG(INFO) << "The scheduler start run recovery!" | |||||
| << " The worker num:" << clusterConfig.initial_worker_num | |||||
| << ", the server num:" << clusterConfig.initial_server_num | |||||
| << ", the scheduler ip:" << clusterConfig.scheduler_host | |||||
| << ", the scheduler port:" << clusterConfig.scheduler_port | |||||
| << ", the initial total node num:" << clusterConfig.initial_total_node_num | |||||
| << ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id | |||||
| << ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id | |||||
| << ", the initial cluster state:" << kClusterState.at(clusterConfig.initial_cluster_state); | |||||
| if (!clusterConfig.initial_registered_nodes_infos.empty()) { | |||||
| for (const auto kvs : clusterConfig.initial_registered_nodes_infos) { | |||||
| MS_LOG(INFO) << "The ip:" << kvs.second.ip_ << ", the port:" << kvs.second.port_ | |||||
| << ", the node_id:" << kvs.second.node_id_ | |||||
| << ", the node_role:" << CommUtil::NodeRoleToString(kvs.second.node_role_) | |||||
| << ", the rank_id_:" << kvs.second.rank_id_ | |||||
| << ", the is_alive:" << CommUtil::BoolToString(kvs.second.is_alive); | |||||
| } | |||||
| } | |||||
| uint32_t worker_num = clusterConfig.initial_worker_num; | uint32_t worker_num = clusterConfig.initial_worker_num; | ||||
| uint32_t server_num = clusterConfig.initial_server_num; | uint32_t server_num = clusterConfig.initial_server_num; | ||||
| @@ -94,6 +112,11 @@ void SchedulerNode::RunRecovery() { | |||||
| node_manager_.set_next_worker_rank_id(clusterConfig.initial_next_worker_rank_id); | node_manager_.set_next_worker_rank_id(clusterConfig.initial_next_worker_rank_id); | ||||
| node_manager_.set_next_server_rank_id(clusterConfig.initial_next_server_rank_id); | node_manager_.set_next_server_rank_id(clusterConfig.initial_next_server_rank_id); | ||||
| node_manager_.set_total_node_num(clusterConfig.initial_total_node_num); | node_manager_.set_total_node_num(clusterConfig.initial_total_node_num); | ||||
| if (clusterConfig.initial_cluster_state == ClusterState::CLUSTER_DISABLE_FLS) { | |||||
| MS_LOG(WARNING) << "Scheduler recover and update cluster state from recovery file, cluster state is " | |||||
| << CommUtil::ClusterStateToString(clusterConfig.initial_cluster_state); | |||||
| node_manager_.UpdateClusterState(clusterConfig.initial_cluster_state); | |||||
| } | |||||
| for (const auto &kvs : initial_node_infos) { | for (const auto &kvs : initial_node_infos) { | ||||
| auto &node_id = kvs.first; | auto &node_id = kvs.first; | ||||
| @@ -352,44 +375,56 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server, | |||||
| "will exit later."; | "will exit later."; | ||||
| return; | return; | ||||
| } | } | ||||
| if (!BuildingNetwork()) { | |||||
| MS_LOG(ERROR) << "Building network failed! Cluster will exit later."; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) { | |||||
| auto nodes = node_manager_.nodes_info(); | |||||
| for (const auto &id : scale_in_node_ids_) { | |||||
| MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id; | |||||
| if (nodes.count(id)) { | |||||
| auto scale_in_client = GetOrCreateClient(nodes[id]); | |||||
| SendMetadata(scale_in_client, nodes[id].rank_id_); | |||||
| node_manager_.UpdateHeartbeat(id); | |||||
| } | |||||
| if (connected_nodes_.count(id)) { | |||||
| MS_LOG(INFO) << "remove scale in node id: " << id << " connection."; | |||||
| connected_nodes_.erase(id); | |||||
| } | |||||
| bool SchedulerNode::BuildingNetwork() { | |||||
| if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) { | |||||
| auto nodes = node_manager_.nodes_info(); | |||||
| for (const auto &id : scale_in_node_ids_) { | |||||
| MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id; | |||||
| if (nodes.count(id)) { | |||||
| auto scale_in_client = GetOrCreateClient(nodes[id]); | |||||
| SendMetadata(scale_in_client, nodes[id].rank_id_); | |||||
| node_manager_.UpdateHeartbeat(id); | |||||
| } | |||||
| if (connected_nodes_.count(id)) { | |||||
| MS_LOG(INFO) << "remove scale in node id: " << id << " connection."; | |||||
| connected_nodes_.erase(id); | |||||
| } | } | ||||
| } | } | ||||
| node_manager_.UpdateNodesInfo(); | |||||
| auto node_infos = node_manager_.nodes_info(); | |||||
| bool res = SendPrepareBuildingNetwork(node_infos); | |||||
| if (!res) { | |||||
| MS_LOG(ERROR) << "Prepare for building network failed! Cluster will exit later."; | |||||
| return; | |||||
| } | |||||
| is_ready_ = true; | |||||
| MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and " | |||||
| << node_manager_.server_num() | |||||
| << " servers registered to scheduer, so the scheduler send meta data to worker/server."; | |||||
| } | |||||
| node_manager_.UpdateNodesInfo(); | |||||
| auto node_infos = node_manager_.nodes_info(); | |||||
| bool res = SendPrepareBuildingNetwork(node_infos); | |||||
| if (!res) { | |||||
| MS_LOG(ERROR) << "Prepare for building network failed!"; | |||||
| return false; | |||||
| } | |||||
| is_ready_ = true; | |||||
| MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and " | |||||
| << node_manager_.server_num() | |||||
| << " servers registered to scheduer, so the scheduler send meta data to worker/server."; | |||||
| for (const auto &kvs : node_infos) { | |||||
| auto client = GetOrCreateClient(kvs.second); | |||||
| MS_EXCEPTION_IF_NULL(client); | |||||
| SendMetadata(client, kvs.second.rank_id_); | |||||
| node_manager_.UpdateHeartbeat(kvs.first); | |||||
| } | |||||
| for (const auto &kvs : node_infos) { | |||||
| auto client = GetOrCreateClient(kvs.second); | |||||
| MS_EXCEPTION_IF_NULL(client); | |||||
| SendMetadata(client, kvs.second.rank_id_); | |||||
| node_manager_.UpdateHeartbeat(kvs.first); | |||||
| } | |||||
| if (node_manager_.GetClusterState() == ClusterState::CLUSTER_DISABLE_FLS) { | |||||
| MS_LOG(WARNING) | |||||
| << "Cluster state is CLUSTER_DISABLE_FLS, do not need to change to CLUSTER_READY when building network."; | |||||
| } else { | |||||
| node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); | node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); | ||||
| PersistMetaData(); | |||||
| wait_start_cond_.notify_all(); | |||||
| } | } | ||||
| PersistMetaData(); | |||||
| wait_start_cond_.notify_all(); | |||||
| return true; | |||||
| } | } | ||||
| void SchedulerNode::ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | void SchedulerNode::ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | ||||
| @@ -897,7 +932,7 @@ bool SchedulerNode::QueryNodeScaleState(const std::shared_ptr<HttpMessageHandler | |||||
| auto client = GetOrCreateClient(kvs.second); | auto client = GetOrCreateClient(kvs.second); | ||||
| MS_EXCEPTION_IF_NULL(client); | MS_EXCEPTION_IF_NULL(client); | ||||
| MS_EXCEPTION_IF_NULL(instance_manager_); | MS_EXCEPTION_IF_NULL(instance_manager_); | ||||
| instance_manager_->QueryNodeScaleState(client, node_manager_, request_id, node_info_); | |||||
| instance_manager_->QueryNodeScaleState(client, node_manager_, request_id, kvs.second); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1179,7 +1214,7 @@ void SchedulerNode::ProcessNewInstance(const std::shared_ptr<HttpMessageHandler> | |||||
| auto client = GetOrCreateClient(kvs.second); | auto client = GetOrCreateClient(kvs.second); | ||||
| MS_EXCEPTION_IF_NULL(client); | MS_EXCEPTION_IF_NULL(client); | ||||
| MS_EXCEPTION_IF_NULL(instance_manager_); | MS_EXCEPTION_IF_NULL(instance_manager_); | ||||
| instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, node_info_); | |||||
| instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, kvs.second); | |||||
| } | } | ||||
| } | } | ||||
| bool res = Wait(request_id); | bool res = Wait(request_id); | ||||
| @@ -1246,7 +1281,7 @@ void SchedulerNode::ProcessQueryInstance(const std::shared_ptr<HttpMessageHandle | |||||
| auto client = GetOrCreateClient(kvs.second); | auto client = GetOrCreateClient(kvs.second); | ||||
| MS_EXCEPTION_IF_NULL(client); | MS_EXCEPTION_IF_NULL(client); | ||||
| MS_EXCEPTION_IF_NULL(instance_manager_); | MS_EXCEPTION_IF_NULL(instance_manager_); | ||||
| instance_manager_->QueryInstanceAsync(client, node_manager_, request_id, node_info_); | |||||
| instance_manager_->QueryInstanceAsync(client, node_manager_, request_id, kvs.second); | |||||
| } | } | ||||
| } | } | ||||
| bool res = Wait(request_id); | bool res = Wait(request_id); | ||||
| @@ -1308,7 +1343,7 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> & | |||||
| auto client = GetOrCreateClient(kvs.second); | auto client = GetOrCreateClient(kvs.second); | ||||
| MS_EXCEPTION_IF_NULL(client); | MS_EXCEPTION_IF_NULL(client); | ||||
| MS_EXCEPTION_IF_NULL(instance_manager_); | MS_EXCEPTION_IF_NULL(instance_manager_); | ||||
| instance_manager_->EnableFLSAsync(client, node_manager_, request_id, node_info_); | |||||
| instance_manager_->EnableFLSAsync(client, node_manager_, request_id, kvs.second); | |||||
| } | } | ||||
| } | } | ||||
| bool res = Wait(request_id); | bool res = Wait(request_id); | ||||
| @@ -1334,6 +1369,7 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> & | |||||
| js["code"] = kSuccessCode; | js["code"] = kSuccessCode; | ||||
| js["result"] = true; | js["result"] = true; | ||||
| node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); | node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); | ||||
| PersistMetaData(); | |||||
| } else { | } else { | ||||
| js["message"] = "start enabling FL-Server failed."; | js["message"] = "start enabling FL-Server failed."; | ||||
| js["code"] = kErrorCode; | js["code"] = kErrorCode; | ||||
| @@ -1379,7 +1415,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler> | |||||
| auto client = GetOrCreateClient(kvs.second); | auto client = GetOrCreateClient(kvs.second); | ||||
| MS_EXCEPTION_IF_NULL(client); | MS_EXCEPTION_IF_NULL(client); | ||||
| MS_EXCEPTION_IF_NULL(instance_manager_); | MS_EXCEPTION_IF_NULL(instance_manager_); | ||||
| instance_manager_->DisableFLSAsync(client, node_manager_, request_id, node_info_); | |||||
| instance_manager_->DisableFLSAsync(client, node_manager_, request_id, kvs.second); | |||||
| } | } | ||||
| } | } | ||||
| bool res = Wait(request_id); | bool res = Wait(request_id); | ||||
| @@ -1404,6 +1440,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler> | |||||
| js["code"] = kSuccessCode; | js["code"] = kSuccessCode; | ||||
| js["result"] = true; | js["result"] = true; | ||||
| node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS); | node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS); | ||||
| PersistMetaData(); | |||||
| } else { | } else { | ||||
| js["message"] = "start disabling FL-Server failed."; | js["message"] = "start disabling FL-Server failed."; | ||||
| js["code"] = kErrorCode; | js["code"] = kErrorCode; | ||||
| @@ -1565,6 +1602,7 @@ void SchedulerNode::PersistMetaData() { | |||||
| return; | return; | ||||
| } | } | ||||
| if (!is_ready_) { | if (!is_ready_) { | ||||
| MS_LOG(WARNING) << "Cluster is not building network successful, do not persist meta data"; | |||||
| return; | return; | ||||
| } | } | ||||
| if (config_->Exists(kKeyRecovery)) { | if (config_->Exists(kKeyRecovery)) { | ||||
| @@ -1576,6 +1614,7 @@ void SchedulerNode::PersistMetaData() { | |||||
| clusterConfig.initial_next_server_rank_id = node_manager_.next_server_rank_id(); | clusterConfig.initial_next_server_rank_id = node_manager_.next_server_rank_id(); | ||||
| clusterConfig.initial_registered_nodes_infos.clear(); | clusterConfig.initial_registered_nodes_infos.clear(); | ||||
| clusterConfig.initial_registered_nodes_infos = node_manager_.registered_nodes_info(); | clusterConfig.initial_registered_nodes_infos = node_manager_.registered_nodes_info(); | ||||
| clusterConfig.initial_cluster_state = node_manager_.GetClusterState(); | |||||
| scheduler_recovery_->Persist(clusterConfig); | scheduler_recovery_->Persist(clusterConfig); | ||||
| scheduler_recovery_->PersistNodesInfo(clusterConfig); | scheduler_recovery_->PersistNodesInfo(clusterConfig); | ||||
| @@ -217,6 +217,8 @@ class BACKEND_EXPORT SchedulerNode : public Node { | |||||
| void GeneralResponse(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | void GeneralResponse(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | ||||
| const std::shared_ptr<MessageMeta> &meta, bool is_success, const std::string &error); | const std::shared_ptr<MessageMeta> &meta, bool is_success, const std::string &error); | ||||
| bool BuildingNetwork(); | |||||
| std::shared_ptr<TcpServer> server_; | std::shared_ptr<TcpServer> server_; | ||||
| std::unique_ptr<std::thread> scheduler_thread_; | std::unique_ptr<std::thread> scheduler_thread_; | ||||
| std::unique_ptr<std::thread> update_state_thread_; | std::unique_ptr<std::thread> update_state_thread_; | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "ps/core/scheduler_recovery.h" | #include "ps/core/scheduler_recovery.h" | ||||
| #include "ps/core/comm_util.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -63,11 +64,6 @@ bool SchedulerRecovery::Recover() { | |||||
| MS_LOG(EXCEPTION) << kRecoverySchedulerPort << " is not contained in " << recovery_storage_->file_path(); | MS_LOG(EXCEPTION) << kRecoverySchedulerPort << " is not contained in " << recovery_storage_->file_path(); | ||||
| } | } | ||||
| MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num | |||||
| << ", the server num:" << clusterConfig.initial_server_num | |||||
| << ", the scheduler ip:" << clusterConfig.scheduler_host | |||||
| << ", the scheduler port:" << clusterConfig.scheduler_port; | |||||
| MS_ERROR_IF_NULL_W_RET_VAL(scheduler_recovery_storage_, false); | MS_ERROR_IF_NULL_W_RET_VAL(scheduler_recovery_storage_, false); | ||||
| // 5. recover total node num | // 5. recover total node num | ||||
| if (scheduler_recovery_storage_->Exists(kRecoveryTotalNodeNum)) { | if (scheduler_recovery_storage_->Exists(kRecoveryTotalNodeNum)) { | ||||
| @@ -110,7 +106,6 @@ bool SchedulerRecovery::Recover() { | |||||
| node_info.port_ = static_cast<uint16_t>(std::strtol(port.c_str(), nullptr, kBase)); | node_info.port_ = static_cast<uint16_t>(std::strtol(port.c_str(), nullptr, kBase)); | ||||
| node_info.node_id_ = elem.at("node_id"); | node_info.node_id_ = elem.at("node_id"); | ||||
| node_info.rank_id_ = UlongToUint(std::strtoul(rank_id.c_str(), nullptr, kBase)); | node_info.rank_id_ = UlongToUint(std::strtoul(rank_id.c_str(), nullptr, kBase)); | ||||
| node_info.is_alive = CommUtil::StringToBool(elem.at("alive")); | |||||
| node_info.node_role_ = CommUtil::StringToNodeRole(elem.at("role")); | node_info.node_role_ = CommUtil::StringToNodeRole(elem.at("role")); | ||||
| nodes_infos[node_info.node_id_] = node_info; | nodes_infos[node_info.node_id_] = node_info; | ||||
| @@ -127,18 +122,11 @@ bool SchedulerRecovery::Recover() { | |||||
| MS_LOG(EXCEPTION) << kRecoveryRegisteredNodesInfos << " is not contained in " << recovery_storage_->file_path(); | MS_LOG(EXCEPTION) << kRecoveryRegisteredNodesInfos << " is not contained in " << recovery_storage_->file_path(); | ||||
| } | } | ||||
| MS_LOG(INFO) << ", the initial total node num:" << clusterConfig.initial_total_node_num | |||||
| << ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id | |||||
| << ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id; | |||||
| if (!clusterConfig.initial_registered_nodes_infos.empty()) { | |||||
| for (const auto kvs : clusterConfig.initial_registered_nodes_infos) { | |||||
| MS_LOG(INFO) << "The ip:" << kvs.second.ip_ << ", the port:" << kvs.second.port_ | |||||
| << ", the node_id:" << kvs.second.node_id_ | |||||
| << ", the node_role:" << CommUtil::NodeRoleToString(kvs.second.node_role_) | |||||
| << ", the rank_id_:" << kvs.second.rank_id_ | |||||
| << ", the is_alive:" << CommUtil::BoolToString(kvs.second.is_alive); | |||||
| } | |||||
| // 9. recover cluster state | |||||
| if (recovery_storage_->Exists(kRecoveryClusterState)) { | |||||
| clusterConfig.initial_cluster_state = kClusterStateMap.at(recovery_storage_->GetString(kRecoveryClusterState, "")); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << kRecoveryClusterState << " is not contained in " << recovery_storage_->file_path(); | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -209,8 +209,9 @@ void PSContext::set_rank_id(uint32_t rank_id) const { | |||||
| } | } | ||||
| void PSContext::set_server_mode(const std::string &server_mode) { | void PSContext::set_server_mode(const std::string &server_mode) { | ||||
| if (server_mode != kServerModePS && server_mode != kServerModeFL) { | |||||
| MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL; | |||||
| if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) { | |||||
| MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL | |||||
| << " or " << kServerModeHybrid; | |||||
| return; | return; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it."; | MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it."; | ||||