| @@ -292,27 +292,9 @@ constexpr auto kNetworkError = "Cluster networking failed."; | |||
| enum class ResultCode { | |||
| // If the method is successfully called and round kernel's residual methods should be called, return kSuccess. | |||
| kSuccess = 0, | |||
| // If there's error happened in the method and residual methods should not be called but this iteration continues, | |||
| // return kSuccessAndReturn so that framework won't drop this iteration. | |||
| kSuccessAndReturn, | |||
| // If there's error happened and this iteration should be dropped, return kFail. | |||
| // If there's error happened, return kFail. | |||
| kFail | |||
| }; | |||
| bool inline ConvertResultCode(ResultCode result_code) { | |||
| switch (result_code) { | |||
| case ResultCode::kSuccess: | |||
| return true; | |||
| case ResultCode::kSuccessAndReturn: | |||
| return true; | |||
| case ResultCode::kFail: | |||
| return false; | |||
| default: | |||
| return true; | |||
| } | |||
| } | |||
| // Definitions for Parameter Server. | |||
| } // namespace server | |||
| } // namespace fl | |||
| } // namespace mindspore | |||
| @@ -142,7 +142,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| return true; | |||
| } | |||
| bool DistributedCountService::CountReachThreshold(const std::string &name) { | |||
| bool DistributedCountService::CountReachThreshold(const std::string &name, const std::string &fl_id) { | |||
| MS_LOG(DEBUG) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name; | |||
| if (local_rank_ == counting_server_rank_) { | |||
| if (global_threshold_count_.count(name) == 0) { | |||
| @@ -162,8 +162,9 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) { | |||
| std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr; | |||
| if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_, | |||
| ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) { | |||
| MS_LOG(WARNING) << "Sending querying whether count reaches threshold message to leader server failed for " | |||
| << name; | |||
| std::string reason = "Sending querying whether count reaches " + name + | |||
| " threshold message to leader server failed" + (fl_id.empty() ? "" : " for fl id " + fl_id); | |||
| MS_LOG(WARNING) << reason; | |||
| return false; | |||
| } | |||
| @@ -73,7 +73,7 @@ class DistributedCountService { | |||
| // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, | |||
| // this method returns true. | |||
| bool CountReachThreshold(const std::string &name); | |||
| bool CountReachThreshold(const std::string &name, const std::string &fl_id = ""); | |||
| // Reset the count of the name to 0. | |||
| void ResetCounter(const std::string &name); | |||
| @@ -55,7 +55,10 @@ bool PushMetricsKernel::Launch(const uint8_t *req_data, size_t len, | |||
| ResultCode result_code = PushMetrics(fbb, push_metrics_req); | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool PushMetricsKernel::Reset() { | |||
| @@ -74,8 +77,8 @@ void PushMetricsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::Message | |||
| ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestPushMetrics *push_metrics_req) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(push_metrics_req, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(push_metrics_req, ResultCode::kFail); | |||
| float loss = push_metrics_req->loss(); | |||
| float accuracy = push_metrics_req->accuracy(); | |||
| @@ -87,7 +90,7 @@ ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb, | |||
| std::string reason = "Count for push metrics request failed."; | |||
| BuildPushMetricsRsp(fbb, schema::ResponseCode_SystemError); | |||
| MS_LOG(ERROR) << reason; | |||
| return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| BuildPushMetricsRsp(fbb, schema::ResponseCode_SUCCEED); | |||
| @@ -60,7 +60,10 @@ bool PushWeightKernel::Launch(const uint8_t *req_data, size_t len, | |||
| ResultCode result_code = PushWeight(fbb, push_weight_req); | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool PushWeightKernel::Reset() { | |||
| @@ -79,8 +82,8 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageH | |||
| ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestPushWeight *push_weight_req) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, ResultCode::kFail); | |||
| size_t iteration = IntToSize(push_weight_req->iteration()); | |||
| size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| if (iteration != current_iter) { | |||
| @@ -88,7 +91,7 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| ", current iteration:" + std::to_string(current_iter); | |||
| BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter); | |||
| MS_LOG(WARNING) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req); | |||
| @@ -96,14 +99,14 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| std::string reason = "PushWeight feature_map is empty."; | |||
| BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter); | |||
| MS_LOG(ERROR) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| if (!executor_->HandlePushWeight(upload_feature_map)) { | |||
| std::string reason = "Pushing weight failed."; | |||
| BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter); | |||
| MS_LOG(ERROR) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds."; | |||
| @@ -112,7 +115,7 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| std::string reason = "Count for push weight request failed."; | |||
| BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter); | |||
| MS_LOG(ERROR) << reason; | |||
| return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter); | |||
| return ResultCode::kSuccess; | |||
| @@ -25,7 +25,7 @@ namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| namespace kernel { | |||
| void ReconstructSecretsKernel::InitKernel(size_t) { | |||
| void ReconstructSecretsKernel::InitKernel(size_t required_cnt) { | |||
| if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { | |||
| iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); | |||
| } | |||
| @@ -41,7 +41,7 @@ void ReconstructSecretsKernel::InitKernel(size_t) { | |||
| name_unmask_ = "UnMaskKernel"; | |||
| MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : " | |||
| << LocalMetaStore::GetInstance().curr_iter_num(); | |||
| DistributedCountService::GetInstance().RegisterCounter(name_unmask_, ps::PSContext::instance()->initial_server_num(), | |||
| DistributedCountService::GetInstance().RegisterCounter(name_unmask_, required_cnt, | |||
| {first_cnt_handler, last_cnt_handler}); | |||
| } | |||
| @@ -59,7 +59,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| std::string reason = "FBBuilder builder or req_data is nullptr."; | |||
| MS_LOG(WARNING) << reason; | |||
| SendResponseMsg(message, reason.c_str(), reason.size()); | |||
| return true; | |||
| return false; | |||
| } | |||
| flatbuffers::Verifier verifier(req_data, len); | |||
| @@ -68,13 +68,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_RequestError, reason, false, ""); | |||
| MS_LOG(WARNING) << reason; | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| ResultCode result_code = ReachThresholdForStartFLJob(fbb); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| return false; | |||
| } | |||
| const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data); | |||
| @@ -85,17 +79,23 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(WARNING) << reason; | |||
| SendResponseMsg(message, reason.c_str(), reason.size()); | |||
| return true; | |||
| return false; | |||
| } | |||
| ResultCode result_code = ReachThresholdForStartFLJob(fbb, start_fl_job_req); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return false; | |||
| } | |||
| if (ps::PSContext::instance()->pki_verify()) { | |||
| if (!JudgeFLJobCert(fbb, start_fl_job_req)) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| return false; | |||
| } | |||
| if (!StoreKeyAttestation(fbb, start_fl_job_req)) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -103,25 +103,25 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| result_code = ReadyForStartFLJob(fbb, device_meta); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| return false; | |||
| } | |||
| PBMetadata metadata; | |||
| *metadata.mutable_device_meta() = device_meta; | |||
| std::string update_reason = ""; | |||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) { | |||
| std::string reason = "Updating device metadata failed. " + update_reason; | |||
| std::string reason = "Updating device metadata failed for fl id " + device_meta.fl_id(); | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return update_reason == kNetworkError ? false : true; | |||
| return false; | |||
| } | |||
| // If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected. | |||
| result_code = CountForStartFLJob(fbb, start_fl_job_req); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| return false; | |||
| } | |||
| IncreaseAcceptClientNum(); | |||
| auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| @@ -133,7 +133,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| fbb->GetBufferPointer(), fbb->GetSize()); | |||
| if (cache == nullptr) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| return false; | |||
| } | |||
| } | |||
| SendResponseMsgInference(message, cache->data(), cache->size(), ModelStore::GetInstance().RelModelResponseCache); | |||
| @@ -237,14 +237,17 @@ void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::Message | |||
| Iteration::GetInstance().SetIterationRunning(); | |||
| } | |||
| ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) { | |||
| if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { | |||
| ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestFLJob *start_fl_job_req) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kFail); | |||
| if (DistributedCountService::GetInstance().CountReachThreshold(name_, start_fl_job_req->fl_id()->str())) { | |||
| std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later."; | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(DEBUG) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| return ResultCode::kSuccess; | |||
| } | |||
| @@ -271,7 +274,7 @@ ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> | |||
| std::string reason = ""; | |||
| if (device_meta.data_size() < 1) { | |||
| reason = "FL job data size is not enough."; | |||
| ret = ResultCode::kSuccessAndReturn; | |||
| ret = ResultCode::kFail; | |||
| } | |||
| if (ret != ResultCode::kSuccess) { | |||
| BuildStartFLJobRsp( | |||
| @@ -284,17 +287,18 @@ ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> | |||
| ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestFLJob *start_fl_job_req) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kFail); | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) { | |||
| std::string reason = "Counting start fl job request failed. Please retry later."; | |||
| std::string reason = | |||
| "Counting start fl job request failed for fl id " + start_fl_job_req->fl_id()->str() + ", Please retry later."; | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(WARNING) << reason; | |||
| return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| return ResultCode::kSuccess; | |||
| } | |||
| @@ -43,7 +43,8 @@ class StartFLJobKernel : public RoundKernel { | |||
| private: | |||
| // Returns whether the startFLJob count of this iteration has reached the threshold. | |||
| ResultCode ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb); | |||
| ResultCode ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestFLJob *start_fl_job_req); | |||
| // The metadata of device will be stored and queried in updateModel round. | |||
| DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req); | |||
| @@ -60,7 +60,7 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| std::string reason = "FBBuilder builder or req_data is nullptr."; | |||
| MS_LOG(WARNING) << reason; | |||
| SendResponseMsg(message, reason.c_str(), reason.size()); | |||
| return true; | |||
| return false; | |||
| } | |||
| flatbuffers::Verifier verifier(req_data, len); | |||
| @@ -69,13 +69,7 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); | |||
| MS_LOG(WARNING) << reason; | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| ResultCode result_code = ReachThresholdForUpdateModel(fbb); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| return false; | |||
| } | |||
| const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data); | |||
| @@ -87,44 +81,30 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| return true; | |||
| } | |||
| // verify signature | |||
| if (ps::PSContext::instance()->pki_verify()) { | |||
| sigVerifyResult verify_result = VerifySignature(update_model_req); | |||
| if (verify_result == sigVerifyResult::FAILED) { | |||
| std::string reason = "verify signature failed."; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); | |||
| MS_LOG(WARNING) << reason; | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| if (verify_result == sigVerifyResult::TIMEOUT) { | |||
| std::string reason = "verify signature timestamp failed."; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, ""); | |||
| MS_LOG(WARNING) << reason; | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| MS_LOG(INFO) << "verify signature passed!"; | |||
| ResultCode result_code = ReachThresholdForUpdateModel(fbb, update_model_req); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return false; | |||
| } | |||
| DeviceMeta device_meta; | |||
| result_code = VerifyUpdateModel(update_model_req, fbb, &device_meta); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| MS_LOG(WARNING) << "Updating model failed."; | |||
| MS_LOG(DEBUG) << "Verify updating model failed."; | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| return false; | |||
| } | |||
| result_code = CountForUpdateModel(fbb, update_model_req); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| return false; | |||
| } | |||
| result_code = UpdateModel(update_model_req, fbb, device_meta); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| MS_LOG(WARNING) << "Updating model failed."; | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| MS_LOG(DEBUG) << "Updating model failed."; | |||
| return false; | |||
| } | |||
| std::string update_model_fl_id = update_model_req->fl_id()->str(); | |||
| @@ -169,41 +149,65 @@ void UpdateModelKernel::RunAggregation() { | |||
| } | |||
| } | |||
| ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) { | |||
| if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { | |||
| ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestUpdateModel *update_model_req) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kFail); | |||
| if (DistributedCountService::GetInstance().CountReachThreshold(name_, update_model_req->fl_id()->str())) { | |||
| std::string reason = "Current amount for updateModel is enough. Please retry later."; | |||
| BuildUpdateModelRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(WARNING) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| return ResultCode::kSuccess; | |||
| } | |||
| ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req, | |||
| const std::shared_ptr<FBBuilder> &fbb, DeviceMeta *device_meta) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(device_meta, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(device_meta, ResultCode::kFail); | |||
| std::string update_model_fl_id = update_model_req->fl_id()->str(); | |||
| size_t iteration = IntToSize(update_model_req->iteration()); | |||
| if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { | |||
| auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp); | |||
| std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + | |||
| ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) + | |||
| ". Retry later at time: " + std::to_string(next_req_time); | |||
| ", Retry later at time: " + std::to_string(next_req_time) + ", fl id is " + update_model_fl_id; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(next_req_time)); | |||
| MS_LOG(WARNING) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| // verify signature | |||
| if (ps::PSContext::instance()->pki_verify()) { | |||
| sigVerifyResult verify_result = VerifySignature(update_model_req); | |||
| if (verify_result == sigVerifyResult::FAILED) { | |||
| std::string reason = "verify signature failed for fl id " + update_model_fl_id; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); | |||
| MS_LOG(WARNING) << reason; | |||
| return ResultCode::kFail; | |||
| } | |||
| if (verify_result == sigVerifyResult::TIMEOUT) { | |||
| std::string reason = "verify signature timestamp failed for fl id " + update_model_fl_id; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, ""); | |||
| MS_LOG(WARNING) << reason; | |||
| return ResultCode::kFail; | |||
| } | |||
| MS_LOG(DEBUG) << "verify signature passed!"; | |||
| } | |||
| std::unordered_map<std::string, size_t> feature_map; | |||
| auto upload_feature_map = update_model_req->feature_map(); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail); | |||
| for (uint32_t i = 0; i < upload_feature_map->size(); i++) { | |||
| const auto &item = upload_feature_map->Get(i); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail); | |||
| std::string weight_full_name = item->weight_fullname()->str(); | |||
| size_t weight_size = item->data()->size() * sizeof(float); | |||
| @@ -212,7 +216,7 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel | |||
| bool verifyFeatureMapIsSuccess; | |||
| if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType && update_model_req->sign() != 0) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->index_array(), ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->index_array(), ResultCode::kFail); | |||
| verifyFeatureMapIsSuccess = VerifySignDSFeatureMap(feature_map, update_model_req); | |||
| } else { | |||
| verifyFeatureMapIsSuccess = LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map); | |||
| @@ -222,10 +226,9 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel | |||
| std::string reason = "Verify model feature map failed, retry later at time: " + std::to_string(next_req_time); | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(next_req_time)); | |||
| MS_LOG(WARNING) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| std::string update_model_fl_id = update_model_req->fl_id()->str(); | |||
| MS_LOG(DEBUG) << "UpdateModel for fl id " << update_model_fl_id; | |||
| bool found = DistributedMetadataStore::GetInstance().GetOneDeviceMeta(update_model_fl_id, device_meta); | |||
| @@ -235,7 +238,7 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(WARNING) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) { | |||
| std::vector<std::string> get_secrets_clients; | |||
| @@ -250,7 +253,7 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(WARNING) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| } | |||
| return ResultCode::kSuccess; | |||
| @@ -279,8 +282,9 @@ bool UpdateModelKernel::VerifySignDSFeatureMap(const std::unordered_map<std::str | |||
| ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, | |||
| const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kFail); | |||
| std::string update_model_fl_id = update_model_req->fl_id()->str(); | |||
| size_t data_size = device_meta.data_size(); | |||
| @@ -293,17 +297,17 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| } | |||
| if (feature_map.empty()) { | |||
| std::string reason = "Feature map is empty."; | |||
| std::string reason = "Feature map is empty for fl id " + update_model_fl_id; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); | |||
| MS_LOG(WARNING) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| for (auto weight : feature_map) { | |||
| weight.second[kNewDataSize].addr = &data_size; | |||
| weight.second[kNewDataSize].size = sizeof(size_t); | |||
| if (!executor_->HandleModelUpdate(weight.first, weight.second)) { | |||
| std::string reason = "Updating weight " + weight.first + " failed."; | |||
| std::string reason = "Updating weight " + weight.first + " failed for fl id " + update_model_fl_id; | |||
| BuildUpdateModelRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| @@ -318,12 +322,12 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| *comm_value.mutable_fl_id() = fl_id; | |||
| std::string update_reason = ""; | |||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) { | |||
| std::string reason = "Updating metadata of UpdateModelClientList failed. " + update_reason; | |||
| std::string reason = "Updating metadata of UpdateModelClientList failed for fl id " + update_model_fl_id; | |||
| BuildUpdateModelRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(WARNING) << reason; | |||
| return update_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| UpdateClientUploadLoss(update_model_req->upload_loss()); | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready", | |||
| @@ -404,16 +408,17 @@ ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) | |||
| ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestUpdateModel *update_model_req) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail); | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) { | |||
| std::string reason = "Counting for update model request failed. Please retry later. " + count_reason; | |||
| std::string reason = "Counting for update model request failed for fl id " + update_model_req->fl_id()->str() + | |||
| ", Please retry later. " + count_reason; | |||
| BuildUpdateModelRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(WARNING) << reason; | |||
| return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn; | |||
| return ResultCode::kFail; | |||
| } | |||
| return ResultCode::kSuccess; | |||
| } | |||
| @@ -53,7 +53,8 @@ class UpdateModelKernel : public RoundKernel { | |||
| void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override; | |||
| private: | |||
| ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb); | |||
| ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestUpdateModel *update_model_req); | |||
| ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb, | |||
| const DeviceMeta &device_meta); | |||
| std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req); | |||
| @@ -19,7 +19,6 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "fl/server/executor.h" | |||
| #include "pipeline/jit/parse/parse.h" | |||
| #include "include/common/utils/python_adapter.h" | |||
| namespace mindspore { | |||
| @@ -86,7 +86,11 @@ bool Round::ReInitForScaling(uint32_t server_num) { | |||
| } | |||
| MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false); | |||
| kernel_->InitKernel(threshold_count_); | |||
| if (name_ == "reconstructSecrets") { | |||
| kernel_->InitKernel(server_num); | |||
| } else { | |||
| kernel_->InitKernel(threshold_count_); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -128,8 +132,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m | |||
| bool ret = kernel_->Launch(reinterpret_cast<const uint8_t *>(message->data()), message->len(), message); | |||
| // Must send response back no matter what value Launch method returns. | |||
| if (!ret) { | |||
| reason = "Launching round kernel of round " + name_ + " failed."; | |||
| Iteration::GetInstance().NotifyNext(false, reason); | |||
| MS_LOG(DEBUG) << "Launching round kernel of round " + name_ + " failed."; | |||
| } | |||
| (void)(Iteration::GetInstance().running_round_num_--); | |||
| return; | |||
| @@ -431,7 +431,11 @@ void Server::RegisterRoundKernel() { | |||
| } | |||
| // For some round kernels, the threshold count should be set. | |||
| round_kernel->InitKernel(round->threshold_count()); | |||
| if (name == "reconstructSecrets") { | |||
| round_kernel->InitKernel(server_node_->server_num()); | |||
| } else { | |||
| round_kernel->InitKernel(round->threshold_count()); | |||
| } | |||
| round->BindRoundKernel(round_kernel); | |||
| } | |||
| return; | |||
| @@ -41,8 +41,8 @@ bool SchedulerNode::Start(const uint32_t &timeout) { | |||
| } | |||
| if (PSContext::instance()->scheduler_manage_port() != 0) { | |||
| MS_LOG(WARNING) << "Start the restful scheduler http service, the ip is 127.0.0.1 " | |||
| << ", the port:" << PSContext::instance()->scheduler_manage_port(); | |||
| MS_LOG(INFO) << "Start the restful scheduler http service, the ip is 127.0.0.1 " | |||
| << ", the port:" << PSContext::instance()->scheduler_manage_port(); | |||
| StartRestfulServer(kLocalIp, PSContext::instance()->scheduler_manage_port(), 1); | |||
| } | |||
| Initialize(); | |||