Browse Source

fix I4X31J and I4VPZ5 and I4WSW2

r1.7
twc 4 years ago
parent
commit
08fc5c001b
14 changed files with 133 additions and 127 deletions
  1. +1
    -19
      mindspore/ccsrc/fl/server/common.h
  2. +4
    -3
      mindspore/ccsrc/fl/server/distributed_count_service.cc
  3. +1
    -1
      mindspore/ccsrc/fl/server/distributed_count_service.h
  4. +7
    -4
      mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc
  5. +10
    -7
      mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc
  6. +2
    -2
      mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc
  7. +28
    -24
      mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc
  8. +2
    -1
      mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h
  9. +63
    -58
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc
  10. +2
    -1
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h
  11. +0
    -1
      mindspore/ccsrc/fl/server/model_store.cc
  12. +6
    -3
      mindspore/ccsrc/fl/server/round.cc
  13. +5
    -1
      mindspore/ccsrc/fl/server/server.cc
  14. +2
    -2
      mindspore/ccsrc/ps/core/scheduler_node.cc

+ 1
- 19
mindspore/ccsrc/fl/server/common.h View File

@@ -292,27 +292,9 @@ constexpr auto kNetworkError = "Cluster networking failed.";
enum class ResultCode {
// If the method is successfully called and round kernel's residual methods should be called, return kSuccess.
kSuccess = 0,
// If there's error happened in the method and residual methods should not be called but this iteration continues,
// return kSuccessAndReturn so that framework won't drop this iteration.
kSuccessAndReturn,
// If there's error happened and this iteration should be dropped, return kFail.
// If there's error happened, return kFail.
kFail
};

bool inline ConvertResultCode(ResultCode result_code) {
switch (result_code) {
case ResultCode::kSuccess:
return true;
case ResultCode::kSuccessAndReturn:
return true;
case ResultCode::kFail:
return false;
default:
return true;
}
}

// Definitions for Parameter Server.
} // namespace server
} // namespace fl
} // namespace mindspore


+ 4
- 3
mindspore/ccsrc/fl/server/distributed_count_service.cc View File

@@ -142,7 +142,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
return true;
}

bool DistributedCountService::CountReachThreshold(const std::string &name) {
bool DistributedCountService::CountReachThreshold(const std::string &name, const std::string &fl_id) {
MS_LOG(DEBUG) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
if (local_rank_ == counting_server_rank_) {
if (global_threshold_count_.count(name) == 0) {
@@ -162,8 +162,9 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
MS_LOG(WARNING) << "Sending querying whether count reaches threshold message to leader server failed for "
<< name;
std::string reason = "Sending querying whether count reaches " + name +
" threshold message to leader server failed" + (fl_id.empty() ? "" : " for fl id " + fl_id);
MS_LOG(WARNING) << reason;
return false;
}



+ 1
- 1
mindspore/ccsrc/fl/server/distributed_count_service.h View File

@@ -73,7 +73,7 @@ class DistributedCountService {

// Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count,
// this method returns true.
bool CountReachThreshold(const std::string &name);
bool CountReachThreshold(const std::string &name, const std::string &fl_id = "");

// Reset the count of the name to 0.
void ResetCounter(const std::string &name);


+ 7
- 4
mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc View File

@@ -55,7 +55,10 @@ bool PushMetricsKernel::Launch(const uint8_t *req_data, size_t len,

ResultCode result_code = PushMetrics(fbb, push_metrics_req);
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
if (result_code != ResultCode::kSuccess) {
return false;
}
return true;
}

bool PushMetricsKernel::Reset() {
@@ -74,8 +77,8 @@ void PushMetricsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::Message

ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestPushMetrics *push_metrics_req) {
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(push_metrics_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(push_metrics_req, ResultCode::kFail);

float loss = push_metrics_req->loss();
float accuracy = push_metrics_req->accuracy();
@@ -87,7 +90,7 @@ ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb,
std::string reason = "Count for push metrics request failed.";
BuildPushMetricsRsp(fbb, schema::ResponseCode_SystemError);
MS_LOG(ERROR) << reason;
return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}

BuildPushMetricsRsp(fbb, schema::ResponseCode_SUCCEED);


+ 10
- 7
mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc View File

@@ -60,7 +60,10 @@ bool PushWeightKernel::Launch(const uint8_t *req_data, size_t len,

ResultCode result_code = PushWeight(fbb, push_weight_req);
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
if (result_code != ResultCode::kSuccess) {
return false;
}
return true;
}

bool PushWeightKernel::Reset() {
@@ -79,8 +82,8 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageH

ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestPushWeight *push_weight_req) {
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, ResultCode::kFail);
size_t iteration = IntToSize(push_weight_req->iteration());
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
if (iteration != current_iter) {
@@ -88,7 +91,7 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
", current iteration:" + std::to_string(current_iter);
BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}

std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
@@ -96,14 +99,14 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
std::string reason = "PushWeight feature_map is empty.";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter);
MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}

if (!executor_->HandlePushWeight(upload_feature_map)) {
std::string reason = "Pushing weight failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}
MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds.";

@@ -112,7 +115,7 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
std::string reason = "Count for push weight request failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
MS_LOG(ERROR) << reason;
return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}
BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter);
return ResultCode::kSuccess;


+ 2
- 2
mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc View File

@@ -25,7 +25,7 @@ namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
void ReconstructSecretsKernel::InitKernel(size_t) {
void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
@@ -41,7 +41,7 @@ void ReconstructSecretsKernel::InitKernel(size_t) {
name_unmask_ = "UnMaskKernel";
MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : "
<< LocalMetaStore::GetInstance().curr_iter_num();
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, ps::PSContext::instance()->initial_server_num(),
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, required_cnt,
{first_cnt_handler, last_cnt_handler});
}



+ 28
- 24
mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc View File

@@ -59,7 +59,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(WARNING) << reason;
SendResponseMsg(message, reason.c_str(), reason.size());
return true;
return false;
}

flatbuffers::Verifier verifier(req_data, len);
@@ -68,13 +68,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
BuildStartFLJobRsp(fbb, schema::ResponseCode_RequestError, reason, false, "");
MS_LOG(WARNING) << reason;
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

ResultCode result_code = ReachThresholdForStartFLJob(fbb);
if (result_code != ResultCode::kSuccess) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
return false;
}

const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
@@ -85,17 +79,23 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
SendResponseMsg(message, reason.c_str(), reason.size());
return true;
return false;
}

ResultCode result_code = ReachThresholdForStartFLJob(fbb, start_fl_job_req);
if (result_code != ResultCode::kSuccess) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}

if (ps::PSContext::instance()->pki_verify()) {
if (!JudgeFLJobCert(fbb, start_fl_job_req)) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
return false;
}
if (!StoreKeyAttestation(fbb, start_fl_job_req)) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
return false;
}
}

@@ -103,25 +103,25 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
result_code = ReadyForStartFLJob(fbb, device_meta);
if (result_code != ResultCode::kSuccess) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
return false;
}
PBMetadata metadata;
*metadata.mutable_device_meta() = device_meta;
std::string update_reason = "";
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) {
std::string reason = "Updating device metadata failed. " + update_reason;
std::string reason = "Updating device metadata failed for fl id " + device_meta.fl_id();
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return update_reason == kNetworkError ? false : true;
return false;
}

// If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected.
result_code = CountForStartFLJob(fbb, start_fl_job_req);
if (result_code != ResultCode::kSuccess) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
return false;
}
IncreaseAcceptClientNum();
auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num();
@@ -133,7 +133,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
fbb->GetBufferPointer(), fbb->GetSize());
if (cache == nullptr) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
return false;
}
}
SendResponseMsgInference(message, cache->data(), cache->size(), ModelStore::GetInstance().RelModelResponseCache);
@@ -237,14 +237,17 @@ void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::Message
Iteration::GetInstance().SetIterationRunning();
}

ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kFail);
if (DistributedCountService::GetInstance().CountReachThreshold(name_, start_fl_job_req->fl_id()->str())) {
std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(DEBUG) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}
return ResultCode::kSuccess;
}
@@ -271,7 +274,7 @@ ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder>
std::string reason = "";
if (device_meta.data_size() < 1) {
reason = "FL job data size is not enough.";
ret = ResultCode::kSuccessAndReturn;
ret = ResultCode::kFail;
}
if (ret != ResultCode::kSuccess) {
BuildStartFLJobRsp(
@@ -284,17 +287,18 @@ ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder>

ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kFail);

std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) {
std::string reason = "Counting start fl job request failed. Please retry later.";
std::string reason =
"Counting start fl job request failed for fl id " + start_fl_job_req->fl_id()->str() + ", Please retry later.";
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}
return ResultCode::kSuccess;
}


+ 2
- 1
mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h View File

@@ -43,7 +43,8 @@ class StartFLJobKernel : public RoundKernel {

private:
// Returns whether the startFLJob count of this iteration has reached the threshold.
ResultCode ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
ResultCode ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req);

// The metadata of device will be stored and queried in updateModel round.
DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req);


+ 63
- 58
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc View File

@@ -60,7 +60,7 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(WARNING) << reason;
SendResponseMsg(message, reason.c_str(), reason.size());
return true;
return false;
}

flatbuffers::Verifier verifier(req_data, len);
@@ -69,13 +69,7 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(WARNING) << reason;
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

ResultCode result_code = ReachThresholdForUpdateModel(fbb);
if (result_code != ResultCode::kSuccess) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
return false;
}

const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
@@ -87,44 +81,30 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
return true;
}

// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(update_model_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(WARNING) << reason;
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, "");
MS_LOG(WARNING) << reason;
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
MS_LOG(INFO) << "verify signature passed!";
ResultCode result_code = ReachThresholdForUpdateModel(fbb, update_model_req);
if (result_code != ResultCode::kSuccess) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}

DeviceMeta device_meta;
result_code = VerifyUpdateModel(update_model_req, fbb, &device_meta);
if (result_code != ResultCode::kSuccess) {
MS_LOG(WARNING) << "Updating model failed.";
MS_LOG(DEBUG) << "Verify updating model failed.";
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
return false;
}

result_code = CountForUpdateModel(fbb, update_model_req);
if (result_code != ResultCode::kSuccess) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
return false;
}

result_code = UpdateModel(update_model_req, fbb, device_meta);
if (result_code != ResultCode::kSuccess) {
MS_LOG(WARNING) << "Updating model failed.";
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
MS_LOG(DEBUG) << "Updating model failed.";
return false;
}
std::string update_model_fl_id = update_model_req->fl_id()->str();
@@ -169,41 +149,65 @@ void UpdateModelKernel::RunAggregation() {
}
}

ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kFail);
if (DistributedCountService::GetInstance().CountReachThreshold(name_, update_model_req->fl_id()->str())) {
std::string reason = "Current amount for updateModel is enough. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}
return ResultCode::kSuccess;
}

ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb, DeviceMeta *device_meta) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(device_meta, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(device_meta, ResultCode::kFail);

std::string update_model_fl_id = update_model_req->fl_id()->str();
size_t iteration = IntToSize(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) +
". Retry later at time: " + std::to_string(next_req_time);
", Retry later at time: " + std::to_string(next_req_time) + ", fl id is " + update_model_fl_id;
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(next_req_time));
MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}

// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(update_model_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed for fl id " + update_model_fl_id;
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(WARNING) << reason;
return ResultCode::kFail;
}

if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed for fl id " + update_model_fl_id;
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, "");
MS_LOG(WARNING) << reason;
return ResultCode::kFail;
}
MS_LOG(DEBUG) << "verify signature passed!";
}

std::unordered_map<std::string, size_t> feature_map;
auto upload_feature_map = update_model_req->feature_map();
MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail);
for (uint32_t i = 0; i < upload_feature_map->size(); i++) {
const auto &item = upload_feature_map->Get(i);
MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail);

std::string weight_full_name = item->weight_fullname()->str();
size_t weight_size = item->data()->size() * sizeof(float);
@@ -212,7 +216,7 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel

bool verifyFeatureMapIsSuccess;
if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType && update_model_req->sign() != 0) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->index_array(), ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->index_array(), ResultCode::kFail);
verifyFeatureMapIsSuccess = VerifySignDSFeatureMap(feature_map, update_model_req);
} else {
verifyFeatureMapIsSuccess = LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map);
@@ -222,10 +226,9 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
std::string reason = "Verify model feature map failed, retry later at time: " + std::to_string(next_req_time);
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(next_req_time));
MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}

std::string update_model_fl_id = update_model_req->fl_id()->str();
MS_LOG(DEBUG) << "UpdateModel for fl id " << update_model_fl_id;

bool found = DistributedMetadataStore::GetInstance().GetOneDeviceMeta(update_model_fl_id, device_meta);
@@ -235,7 +238,7 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}
if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) {
std::vector<std::string> get_secrets_clients;
@@ -250,7 +253,7 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}
}
return ResultCode::kSuccess;
@@ -279,8 +282,9 @@ bool UpdateModelKernel::VerifySignDSFeatureMap(const std::unordered_map<std::str

ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kFail);

std::string update_model_fl_id = update_model_req->fl_id()->str();
size_t data_size = device_meta.data_size();

@@ -293,17 +297,17 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
}

if (feature_map.empty()) {
std::string reason = "Feature map is empty.";
std::string reason = "Feature map is empty for fl id " + update_model_fl_id;
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}

for (auto weight : feature_map) {
weight.second[kNewDataSize].addr = &data_size;
weight.second[kNewDataSize].size = sizeof(size_t);
if (!executor_->HandleModelUpdate(weight.first, weight.second)) {
std::string reason = "Updating weight " + weight.first + " failed.";
std::string reason = "Updating weight " + weight.first + " failed for fl id " + update_model_fl_id;
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
@@ -318,12 +322,12 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
*comm_value.mutable_fl_id() = fl_id;
std::string update_reason = "";
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) {
std::string reason = "Updating metadata of UpdateModelClientList failed. " + update_reason;
std::string reason = "Updating metadata of UpdateModelClientList failed for fl id " + update_model_fl_id;
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return update_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}
UpdateClientUploadLoss(update_model_req->upload_loss());
BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready",
@@ -404,16 +408,17 @@ ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id)

ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req) {
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail);
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) {
std::string reason = "Counting for update model request failed. Please retry later. " + count_reason;
std::string reason = "Counting for update model request failed for fl id " + update_model_req->fl_id()->str() +
", Please retry later. " + count_reason;
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
return ResultCode::kFail;
}
return ResultCode::kSuccess;
}


+ 2
- 1
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h View File

@@ -53,7 +53,8 @@ class UpdateModelKernel : public RoundKernel {
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

private:
ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req);
ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb,
const DeviceMeta &device_meta);
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);


+ 0
- 1
mindspore/ccsrc/fl/server/model_store.cc View File

@@ -19,7 +19,6 @@
#include <string>
#include <memory>
#include "fl/server/executor.h"
#include "pipeline/jit/parse/parse.h"
#include "include/common/utils/python_adapter.h"

namespace mindspore {


+ 6
- 3
mindspore/ccsrc/fl/server/round.cc View File

@@ -86,7 +86,11 @@ bool Round::ReInitForScaling(uint32_t server_num) {
}

MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
kernel_->InitKernel(threshold_count_);
if (name_ == "reconstructSecrets") {
kernel_->InitKernel(server_num);
} else {
kernel_->InitKernel(threshold_count_);
}
return true;
}

@@ -128,8 +132,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
bool ret = kernel_->Launch(reinterpret_cast<const uint8_t *>(message->data()), message->len(), message);
// Must send response back no matter what value Launch method returns.
if (!ret) {
reason = "Launching round kernel of round " + name_ + " failed.";
Iteration::GetInstance().NotifyNext(false, reason);
MS_LOG(DEBUG) << "Launching round kernel of round " + name_ + " failed.";
}
(void)(Iteration::GetInstance().running_round_num_--);
return;


+ 5
- 1
mindspore/ccsrc/fl/server/server.cc View File

@@ -431,7 +431,11 @@ void Server::RegisterRoundKernel() {
}

// For some round kernels, the threshold count should be set.
round_kernel->InitKernel(round->threshold_count());
if (name == "reconstructSecrets") {
round_kernel->InitKernel(server_node_->server_num());
} else {
round_kernel->InitKernel(round->threshold_count());
}
round->BindRoundKernel(round_kernel);
}
return;


+ 2
- 2
mindspore/ccsrc/ps/core/scheduler_node.cc View File

@@ -41,8 +41,8 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
}

if (PSContext::instance()->scheduler_manage_port() != 0) {
MS_LOG(WARNING) << "Start the restful scheduler http service, the ip is 127.0.0.1 "
<< ", the port:" << PSContext::instance()->scheduler_manage_port();
MS_LOG(INFO) << "Start the restful scheduler http service, the ip is 127.0.0.1 "
<< ", the port:" << PSContext::instance()->scheduler_manage_port();
StartRestfulServer(kLocalIp, PSContext::instance()->scheduler_manage_port(), 1);
}
Initialize();


Loading…
Cancel
Save