Browse Source

fix I4RQOC && sync code

feature/build-system-rewrite
twc 4 years ago
parent
commit
997ae1133a
30 changed files with 213 additions and 172 deletions
  1. +44
    -44
      mindspore/ccsrc/fl/server/cert_verify.cc
  2. +11
    -11
      mindspore/ccsrc/fl/server/collective_ops_impl.cc
  3. +25
    -24
      mindspore/ccsrc/fl/server/distributed_count_service.cc
  4. +26
    -26
      mindspore/ccsrc/fl/server/distributed_metadata_store.cc
  5. +1
    -1
      mindspore/ccsrc/fl/server/iteration.cc
  6. +6
    -0
      mindspore/ccsrc/fl/server/iteration_timer.cc
  7. +2
    -0
      mindspore/ccsrc/fl/server/iteration_timer.h
  8. +3
    -3
      mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h
  9. +6
    -6
      mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc
  10. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.h
  11. +12
    -12
      mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc
  12. +21
    -5
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc
  13. +4
    -2
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h
  14. +2
    -2
      mindspore/ccsrc/fl/server/memory_register.h
  15. +3
    -3
      mindspore/ccsrc/fl/server/model_store.cc
  16. +9
    -3
      mindspore/ccsrc/fl/server/round.cc
  17. +4
    -0
      mindspore/ccsrc/fl/server/server_recovery.cc
  18. +2
    -0
      mindspore/ccsrc/fl/server/server_recovery.h
  19. +3
    -8
      mindspore/ccsrc/ps/core/abstract_node.cc
  20. +2
    -3
      mindspore/ccsrc/ps/core/communicator/tcp_client.cc
  21. +1
    -1
      mindspore/ccsrc/ps/core/communicator/tcp_client.h
  22. +3
    -2
      mindspore/ccsrc/ps/core/communicator/tcp_server.cc
  23. +2
    -2
      mindspore/ccsrc/ps/core/node.cc
  24. +1
    -0
      mindspore/ccsrc/ps/core/node_recovery.cc
  25. +5
    -2
      mindspore/ccsrc/ps/core/recovery_base.cc
  26. +5
    -2
      mindspore/ccsrc/ps/core/recovery_base.h
  27. +4
    -4
      mindspore/ccsrc/ps/core/scheduler_node.cc
  28. +2
    -0
      mindspore/ccsrc/ps/core/scheduler_recovery.cc
  29. +2
    -4
      tests/ut/cpp/ps/core/tcp_client_tests.cc
  30. +1
    -1
      tests/ut/cpp/ps/core/tcp_pb_server_test.cc

+ 44
- 44
mindspore/ccsrc/fl/server/cert_verify.cc View File

@@ -72,11 +72,11 @@ bool CertVerify::verifyCertTime(const X509 *cert) const {
return false;
}
if (day < 0) {
MS_LOG(ERROR) << "cert start day time is later than now day time, day is" << day;
MS_LOG(WARNING) << "cert start day time is later than now day time, day is" << day;
return false;
}
if (day == 0 && sec < certStartTimeDiff) {
MS_LOG(ERROR) << "cert start second time is later than 600 second, second is" << sec;
MS_LOG(WARNING) << "cert start second time is later than 600 second, second is" << sec;
return false;
}
day = 0;
@@ -87,10 +87,10 @@ bool CertVerify::verifyCertTime(const X509 *cert) const {
}

if (day < 0 || sec < 0) {
MS_LOG(ERROR) << "cert end time is sooner than now time.";
MS_LOG(WARNING) << "cert end time is sooner than now time.";
return false;
}
MS_LOG(INFO) << "verify cert time success.";
MS_LOG(WARNING) << "verify cert time end.";
return true;
}

@@ -105,20 +105,20 @@ bool CertVerify::verifyPublicKey(const X509 *keyAttestationCertObj, const X509 *
int ret = 0;
ret = X509_verify(const_cast<X509 *>(keyAttestationCertObj), equipPubKey);
if (ret != 1) {
MS_LOG(ERROR) << "keyAttestationCert verify is failed";
MS_LOG(WARNING) << "keyAttestationCert verify is failed";
result = false;
break;
}
ret = X509_verify(const_cast<X509 *>(equipCertObj), equipCAPubKey);
if (ret != 1) {
MS_LOG(ERROR) << "equip cert verify is failed";
MS_LOG(WARNING) << "equip cert verify is failed";
result = false;
break;
}
int ret_first = X509_verify(const_cast<X509 *>(equipCACertObj), rootFirstPubKey);
int ret_second = X509_verify(const_cast<X509 *>(equipCACertObj), rootSecondPubKey);
if (ret_first != 1 && ret_second != 1) {
MS_LOG(ERROR) << "equip ca cert verify is failed";
MS_LOG(WARNING) << "equip ca cert verify is failed";
result = false;
break;
}
@@ -128,7 +128,7 @@ bool CertVerify::verifyPublicKey(const X509 *keyAttestationCertObj, const X509 *
EVP_PKEY_free(equipCAPubKey);
EVP_PKEY_free(rootFirstPubKey);
EVP_PKEY_free(rootSecondPubKey);
MS_LOG(INFO) << "verify Public Key success.";
MS_LOG(WARNING) << "verify Public Key end.";
return result;
}

@@ -143,7 +143,7 @@ bool CertVerify::verifyCAChain(const std::string &keyAttestation, const std::str
bool result = true;
do {
if (rootFirstCA == nullptr || rootSecondCA == nullptr) {
MS_LOG(ERROR) << "rootFirstCA or rootSecondCA is nullptr";
MS_LOG(WARNING) << "rootFirstCA or rootSecondCA is nullptr";
result = false;
break;
}
@@ -158,37 +158,37 @@ bool CertVerify::verifyCAChain(const std::string &keyAttestation, const std::str
}

if (!verifyCertCommonName(equipCACertObj, equipCertObj)) {
MS_LOG(ERROR) << "equip ca cert subject cn is not equal with equip cert issuer cn.";
MS_LOG(WARNING) << "equip ca cert subject cn is not equal with equip cert issuer cn.";
result = false;
break;
}

if (!verifyCertCommonName(rootFirstCA, equipCACertObj) && !verifyCertCommonName(rootSecondCA, equipCACertObj)) {
MS_LOG(ERROR) << "root CA cert subject cn is not equal with equip CA cert issuer cn.";
MS_LOG(WARNING) << "root CA cert subject cn is not equal with equip CA cert issuer cn.";
result = false;
break;
}

if (!verifyExtendedAttributes(equipCACertObj)) {
MS_LOG(ERROR) << "verify equipCACert Extended Attributes failed.";
MS_LOG(WARNING) << "verify equipCACert Extended Attributes failed.";
result = false;
break;
}

if (!verifyCertKeyID(rootFirstCA, equipCACertObj) && !verifyCertKeyID(rootSecondCA, equipCACertObj)) {
MS_LOG(ERROR) << "root CA cert subject keyid is not equal with equip CA cert issuer keyid.";
MS_LOG(WARNING) << "root CA cert subject keyid is not equal with equip CA cert issuer keyid.";
result = false;
break;
}

if (!verifyCertKeyID(equipCACertObj, equipCertObj)) {
MS_LOG(ERROR) << "equip CA cert subject keyid is not equal with equip cert issuer keyid.";
MS_LOG(WARNING) << "equip CA cert subject keyid is not equal with equip cert issuer keyid.";
result = false;
break;
}

if (!verifyPublicKey(keyAttestationCertObj, equipCertObj, equipCACertObj, rootFirstCA, rootSecondCA)) {
MS_LOG(ERROR) << "verify Public Key failed";
MS_LOG(WARNING) << "verify Public Key failed";
result = false;
break;
}
@@ -198,7 +198,7 @@ bool CertVerify::verifyCAChain(const std::string &keyAttestation, const std::str
X509_free(keyAttestationCertObj);
X509_free(equipCertObj);
X509_free(equipCACertObj);
MS_LOG(INFO) << "verifyCAChain success.";
MS_LOG(WARNING) << "verifyCAChain end.";
return result;
}

@@ -232,7 +232,7 @@ bool CertVerify::verifyCertKeyID(const X509 *caCert, const X509 *subCert) const
}
char issuer_keyid[512] = {0};
if (akeyid->keyid == nullptr) {
MS_LOG(ERROR) << "keyid is nullprt.";
MS_LOG(WARNING) << "keyid is nullprt.";
result = false;
break;
}
@@ -271,11 +271,11 @@ bool CertVerify::verifyExtendedAttributes(const X509 *cert) const {
break;
}
if (!bcons->ca) {
MS_LOG(ERROR) << "Subject Type is End Entity.";
MS_LOG(WARNING) << "Subject Type is End Entity.";
result = false;
break;
}
MS_LOG(INFO) << "Subject Type is CA.";
MS_LOG(WARNING) << "Subject Type is CA.";

lASN1UsageStr = reinterpret_cast<ASN1_BIT_STRING *>(X509_get_ext_d2i(cert, NID_key_usage, NULL, NULL));
if (lASN1UsageStr == nullptr) {
@@ -289,11 +289,11 @@ bool CertVerify::verifyExtendedAttributes(const X509 *cert) const {
}

if (!(usage & KU_KEY_CERT_SIGN)) {
MS_LOG(ERROR) << "Subject is not Certificate Signature.";
MS_LOG(WARNING) << "Subject is not Certificate Signature.";
result = false;
break;
}
MS_LOG(INFO) << "Subject is Certificate Signature.";
MS_LOG(WARNING) << "Subject is Certificate Signature.";
} while (0);
BASIC_CONSTRAINTS_free(bcons);
ASN1_BIT_STRING_free(lASN1UsageStr);
@@ -346,14 +346,14 @@ bool CertVerify::verifyCRL(const std::string &equipCert, const std::string &equi
}

if (equipCrl == nullptr) {
MS_LOG(INFO) << "equipCrl is nullptr. return true.";
MS_LOG(WARNING) << "equipCrl is nullptr. return true.";
result = true;
break;
}
evp_pkey = X509_get_pubkey(equipCertObj);
int ret = X509_CRL_verify(equipCrl, evp_pkey);
if (ret == 1) {
MS_LOG(ERROR) << "equip cert in equip crl, verify failed";
MS_LOG(WARNING) << "equip cert in equip crl, verify failed";
result = false;
break;
}
@@ -362,14 +362,14 @@ bool CertVerify::verifyCRL(const std::string &equipCert, const std::string &equi
EVP_PKEY_free(evp_pkey);
X509_free(equipCertObj);
X509_CRL_free(equipCrl);
MS_LOG(INFO) << "verifyCRL success.";
MS_LOG(WARNING) << "verifyCRL end.";
return result;
}

bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned char *signData, const std::string &flID,
const std::string &timeStamp) {
if (keyAttestation.empty() || signData == nullptr || flID.empty() || timeStamp.empty()) {
MS_LOG(ERROR) << "keyAttestation or signData or flID or timeStamp is empty.";
MS_LOG(WARNING) << "keyAttestation or signData or flID or timeStamp is empty.";
return false;
}
bool result = true;
@@ -386,7 +386,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned
pubKey = X509_get_pubkey(keyAttestationCertObj);
RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey);
if (pRSAPublicKey == nullptr) {
MS_LOG(ERROR) << "get rsa public key failed.";
MS_LOG(WARNING) << "get rsa public key failed.";
result = false;
break;
}
@@ -395,7 +395,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned
unsigned char buffer[256];
int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING);
if (ret == -1) {
MS_LOG(ERROR) << "rsa public decrypt failed.";
MS_LOG(WARNING) << "rsa public decrypt failed.";
result = false;
break;
}
@@ -405,9 +405,9 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned
if (ret != 1) {
uint64_t ulErr = ERR_get_error();
char szErrMsg[1024] = {0};
MS_LOG(ERROR) << "verify error. error number: " << ulErr;
MS_LOG(WARNING) << "verify WARNING. WARNING number: " << ulErr;
std::string str_res = ERR_error_string(ulErr, szErrMsg);
MS_LOG(ERROR) << szErrMsg;
MS_LOG(WARNING) << szErrMsg;
if (str_res.empty()) {
result = false;
break;
@@ -420,7 +420,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned
X509_free(keyAttestationCertObj);
CRYPTO_cleanup_all_ex_data();

MS_LOG(INFO) << "verifyRSAKey success.";
MS_LOG(WARNING) << "verifyRSAKey end.";
return result;
}

@@ -445,7 +445,7 @@ void CertVerify::sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash

std::string CertVerify::toHexString(const unsigned char *data, const int len) {
if (data == nullptr) {
MS_LOG(ERROR) << "data hash is null.";
MS_LOG(WARNING) << "data hash is null.";
return "";
}

@@ -465,10 +465,10 @@ bool CertVerify::verifyEquipCertAndFlID(const std::string &flID, const std::stri
sha256Hash(equipCert, hash, SHA256_DIGEST_LENGTH);
std::string equipCertSha256 = toHexString(hash, SHA256_DIGEST_LENGTH);
if (flID == equipCertSha256) {
MS_LOG(INFO) << "verifyEquipCertAndFlID success.";
MS_LOG(WARNING) << "verifyEquipCertAndFlID success.";
return true;
} else {
MS_LOG(ERROR) << "verifyEquipCertAndFlID failed.";
MS_LOG(WARNING) << "verifyEquipCertAndFlID failed.";
return false;
}
}
@@ -482,13 +482,13 @@ bool CertVerify::verifyTimeStamp(const std::string &flID, const std::string &tim
return false;
}
int64_t now = tv.tv_sec * base + tv.tv_usec / base;
MS_LOG(INFO) << "flID: " << flID.c_str() << ",now time: " << now << ",requestTime: " << requestTime;
MS_LOG(WARNING) << "flID: " << flID.c_str() << ",now time: " << now << ",requestTime: " << requestTime;

int64_t diff = now - requestTime;
if (abs(diff) > replayAttackTimeDiff) {
return false;
}
MS_LOG(INFO) << "verifyTimeStamp success.";
MS_LOG(WARNING) << "verifyTimeStamp success.";
return true;
}

@@ -514,7 +514,7 @@ void CertVerify::sha256Hash(const std::string &src, uint8_t *hash, const int len
bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData,
int srcDataLen) {
if (keyAttestation.empty() || signData == nullptr || srcData == nullptr || srcDataLen <= 0) {
MS_LOG(ERROR) << "keyAttestation or signData or srcData is invalid.";
MS_LOG(WARNING) << "keyAttestation or signData or srcData is invalid.";
return false;
}
bool result = true;
@@ -525,7 +525,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *
pubKey = X509_get_pubkey(keyAttestationCertObj);
RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey);
if (pRSAPublicKey == nullptr) {
MS_LOG(ERROR) << "get rsa public key failed.";
MS_LOG(WARNING) << "get rsa public key failed.";
result = false;
break;
}
@@ -534,7 +534,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *
unsigned char buffer[256];
int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING);
if (ret == -1) {
MS_LOG(ERROR) << "rsa public decrypt failed.";
MS_LOG(WARNING) << "rsa public decrypt failed.";
result = false;
break;
}
@@ -544,9 +544,9 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *
if (ret != 1) {
uint64_t ulErr = ERR_get_error();
char szErrMsg[1024] = {0};
MS_LOG(ERROR) << "verify error. error number: " << ulErr;
MS_LOG(WARNING) << "verify WARNING. WARNING number: " << ulErr;
std::string str_res = ERR_error_string(ulErr, szErrMsg);
MS_LOG(ERROR) << szErrMsg;
MS_LOG(WARNING) << szErrMsg;
if (str_res.empty()) {
result = false;
break;
@@ -559,23 +559,23 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *
X509_free(keyAttestationCertObj);
CRYPTO_cleanup_all_ex_data();

MS_LOG(INFO) << "verifyRSAKey success.";
MS_LOG(WARNING) << "verifyRSAKey end.";
return result;
}

bool CertVerify::initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath,
const std::string equipCrlPath, const uint64_t replay_attack_time_diff) {
if (rootFirstCaFilePath.empty() || rootSecondCaFilePath.empty()) {
MS_LOG(ERROR) << "the root or crl path is empty.";
MS_LOG(WARNING) << "the root or crl path is empty.";
return false;
}

if (!checkFileExists(rootFirstCaFilePath)) {
MS_LOG(ERROR) << "The rootFirstCaFilePath is not exist.";
MS_LOG(WARNING) << "The rootFirstCaFilePath is not exist.";
return false;
}
if (!checkFileExists(rootSecondCaFilePath)) {
MS_LOG(ERROR) << "The rootSecondCaFilePath is not exist.";
MS_LOG(WARNING) << "The rootSecondCaFilePath is not exist.";
return false;
}



+ 11
- 11
mindspore/ccsrc/fl/server/collective_ops_impl.cc View File

@@ -63,10 +63,10 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
T *output_buff = reinterpret_cast<T *>(recvbuff);
uint32_t send_to_rank = (rank_id_ + 1) % rank_size;
uint32_t recv_from_rank = (rank_id_ - 1 + rank_size) % rank_size;
MS_LOG(INFO) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", rank_id_:" << rank_id_
<< ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size
<< ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
<< ", recv_from_rank:" << recv_from_rank;
MS_LOG(DEBUG) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", rank_id_:" << rank_id_
<< ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size
<< ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
<< ", recv_from_rank:" << recv_from_rank;

// Ring ReduceScatter.
MS_LOG(DEBUG) << "Start Ring ReduceScatter.";
@@ -148,8 +148,8 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
uint32_t rank_size = server_num_;
MS_LOG(INFO) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", rank_id_:" << rank_id_
<< ", count:" << count;
MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", rank_id_:" << rank_id_
<< ", count:" << count;

size_t src_size = count * sizeof(T);
size_t dst_size = count * sizeof(T);
@@ -192,7 +192,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
MS_LOG(DEBUG) << "End Reduce.";

// Broadcast data to not 0 rank process.
MS_LOG(INFO) << "Start broadcast from rank 0 to other processes.";
MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes.";
if (rank_id_ == 0) {
for (uint32_t i = 1; i < rank_size; i++) {
MS_LOG(DEBUG) << "Broadcast data to process " << i;
@@ -240,9 +240,9 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff

uint32_t send_to_rank = (rank_id_ + 1) % rank_size_;
uint32_t recv_from_rank = (rank_id_ - 1 + rank_size_) % rank_size_;
MS_LOG(INFO) << "Ring AllGather count:" << send_count << ", rank_size:" << rank_size_ << ", rank_id_:" << rank_id_
<< ", chunk_size:" << chunk_size << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
<< ", recv_from_rank:" << recv_from_rank;
MS_LOG(DEBUG) << "Ring AllGather count:" << send_count << ", rank_size:" << rank_size_ << ", rank_id_:" << rank_id_
<< ", chunk_size:" << chunk_size << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
<< ", recv_from_rank:" << recv_from_rank;

T *output_buff = reinterpret_cast<T *>(recvbuff);
size_t src_size = send_count * sizeof(T);
@@ -301,7 +301,7 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c
uint32_t global_root_rank = group_to_global_ranks[root];

// Broadcast data to processes which are not the root.
MS_LOG(INFO) << "Start broadcast from root to other processes.";
MS_LOG(DEBUG) << "Start broadcast from root to other processes.";
if (rank_id_ == global_root_rank) {
for (uint32_t i = 1; i < group_rank_size; i++) {
uint32_t dst_rank = group_to_global_ranks[i];


+ 25
- 24
mindspore/ccsrc/fl/server/distributed_count_service.cc View File

@@ -81,16 +81,16 @@ bool DistributedCountService::ReInitCounter(const std::string &name, size_t glob
}

bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) {
MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
MS_LOG(DEBUG) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
if (local_rank_ == counting_server_rank_) {
if (global_threshold_count_.count(name) == 0) {
MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
MS_LOG(WARNING) << "Counter for " << name << " is not registered.";
return false;
}

std::unique_lock<std::mutex> lock(mutex_[name]);
if (global_current_count_[name].size() >= global_threshold_count_[name]) {
MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is "
MS_LOG(DEBUG) << "Count for " << name << " is already enough. Threshold count is "
<< global_threshold_count_[name];
return false;
}
@@ -98,7 +98,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
(void)global_current_count_[name].insert(id);
if (!TriggerCounterEvent(name, reason)) {
MS_LOG(ERROR) << "Leader server trigger count event failed.";
MS_LOG(WARNING) << "Leader server trigger count event failed.";
return false;
}
} else {
@@ -110,7 +110,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
&report_cnt_rsp_msg)) {
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
MS_LOG(WARNING) << "Sending reporting count message to leader server failed for " << name;
if (reason != nullptr) {
*reason = kNetworkError;
}
@@ -121,7 +121,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
CountResponse count_rsp;
(void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
if (!count_rsp.result()) {
MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
MS_LOG(WARNING) << "Reporting count failed:" << count_rsp.reason();
// If the error is caused by the network issue, return the reason.
if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) {
*reason = kNetworkError;
@@ -133,10 +133,10 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
}

bool DistributedCountService::CountReachThreshold(const std::string &name) {
MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
MS_LOG(DEBUG) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
if (local_rank_ == counting_server_rank_) {
if (global_threshold_count_.count(name) == 0) {
MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
MS_LOG(WARNING) << "Counter for " << name << " is not registered.";
return false;
}

@@ -149,7 +149,8 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
MS_LOG(WARNING) << "Sending querying whether count reaches threshold message to leader server failed for "
<< name;
return false;
}

@@ -202,10 +203,10 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
std::string reason = "Counter for " + name + " is not registered.";
count_rsp.set_result(false);
count_rsp.set_reason(reason);
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
message)) {
MS_LOG(ERROR) << "Sending response failed.";
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
@@ -217,10 +218,10 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
"Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]);
count_rsp.set_result(false);
count_rsp.set_reason(reason);
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
message)) {
MS_LOG(ERROR) << "Sending response failed.";
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
@@ -239,7 +240,7 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
}
if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
message)) {
MS_LOG(ERROR) << "Sending response failed.";
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
@@ -254,7 +255,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(

std::unique_lock<std::mutex> lock(mutex_[name]);
if (global_threshold_count_.count(name) == 0) {
MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
MS_LOG(WARNING) << "Counter for " << name << " is not registered.";
return;
}

@@ -262,7 +263,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(
count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
if (!communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(),
count_reach_threshold_rsp.SerializeAsString().size(), message)) {
MS_LOG(ERROR) << "Sending response failed.";
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
@@ -274,7 +275,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
// callbacks.
std::string couter_event_rsp_msg = "success";
if (!communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message)) {
MS_LOG(ERROR) << "Sending response failed.";
MS_LOG(WARNING) << "Sending response failed.";
return;
}

@@ -284,7 +285,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
const auto &name = counter_event.name();

if (counter_handlers_.count(name) == 0) {
MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
MS_LOG(WARNING) << "The counter handler of " << name << " is not registered.";
return;
}
MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
@@ -293,7 +294,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
} else if (type == CounterEventType::LAST_CNT) {
counter_handlers_[name].last_count_handler(message);
} else {
MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid.";
MS_LOG(WARNING) << "DistributedCountService event type " << type << " is invalid.";
return;
}
return;
@@ -301,7 +302,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:

bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) {
if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) {
MS_LOG(ERROR) << "The counter of " << name << " is not registered.";
MS_LOG(WARNING) << "The counter of " << name << " is not registered.";
return false;
}

@@ -331,7 +332,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
for (uint32_t i = 1; i < server_num_; i++) {
MS_LOG(INFO) << "Start sending first count event message to server " << i;
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
MS_LOG(WARNING) << "Activating first count event to server " << i << " failed.";
if (reason != nullptr) {
*reason = kNetworkError;
}
@@ -340,7 +341,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
}

if (counter_handlers_.count(name) == 0) {
MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
MS_LOG(WARNING) << "The counter handler of " << name << " is not registered.";
return false;
}
// Leader server directly calls the callback.
@@ -360,7 +361,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
for (uint32_t i = 1; i < server_num_; i++) {
MS_LOG(INFO) << "Start sending last count event message to server " << i;
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
MS_LOG(WARNING) << "Activating last count event to server " << i << " failed.";
if (reason != nullptr) {
*reason = kNetworkError;
}
@@ -369,7 +370,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
}

if (counter_handlers_.count(name) == 0) {
MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
MS_LOG(WARNING) << "The counter handler of " << name << " is not registered.";
return false;
}
// Leader server directly calls the callback.


+ 26
- 26
mindspore/ccsrc/fl/server/distributed_metadata_store.cc View File

@@ -43,7 +43,7 @@ void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps:

void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
MS_LOG(WARNING) << "The consistent hash ring is not initialized yet.";
return;
}

@@ -63,14 +63,14 @@ void DistributedMetadataStore::RegisterMetadata(const std::string &name, const P

void DistributedMetadataStore::ResetMetadata(const std::string &name) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
MS_LOG(WARNING) << "The consistent hash ring is not initialized yet.";
return;
}

uint32_t stored_rank = router_->Find(name);
if (local_rank_ == stored_rank) {
if (metadata_.count(name) == 0) {
MS_LOG(ERROR) << "The metadata for " << name << " is not registered.";
MS_LOG(WARNING) << "The metadata for " << name << " is not registered.";
return;
}

@@ -84,15 +84,15 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) {

bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
MS_LOG(WARNING) << "The consistent hash ring is not initialized yet.";
return false;
}

uint32_t stored_rank = router_->Find(name);
MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank;
MS_LOG(DEBUG) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank;
if (local_rank_ == stored_rank) {
if (!DoUpdateMetadata(name, meta)) {
MS_LOG(ERROR) << "Updating meta data failed.";
MS_LOG(WARNING) << "Updating meta data failed.";
return false;
}
} else {
@@ -102,7 +102,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
&update_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
MS_LOG(WARNING) << "Sending updating metadata message to server " << stored_rank << " failed.";
if (reason != nullptr) {
*reason = kNetworkError;
}
@@ -113,7 +113,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
std::string update_meta_rsp =
std::string(reinterpret_cast<char *>(update_meta_rsp_msg->data()), update_meta_rsp_msg->size());
if (update_meta_rsp != kSuccess) {
MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed. " << update_meta_rsp;
MS_LOG(WARNING) << "Updating metadata in server " << stored_rank << " failed. " << update_meta_rsp;
return false;
}
}
@@ -122,11 +122,11 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM

PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
MS_LOG(WARNING) << "The consistent hash ring is not initialized yet.";
return {};
}
uint32_t stored_rank = router_->Find(name);
MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
MS_LOG(DEBUG) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
if (local_rank_ == stored_rank) {
std::unique_lock<std::mutex> lock(mutex_[name]);
return metadata_[name];
@@ -138,7 +138,7 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetMetadata,
&get_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
MS_LOG(WARNING) << "Sending getting metadata message to server " << stored_rank << " failed.";
return get_metadata_rsp;
}

@@ -184,17 +184,17 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
PBMetadataWithName meta_with_name;
(void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
const std::string &name = meta_with_name.name();
MS_LOG(INFO) << "Update metadata for " << name;
MS_LOG(DEBUG) << "Update metadata for " << name;

std::string update_meta_rsp_msg;
if (!DoUpdateMetadata(name, meta_with_name.metadata())) {
update_meta_rsp_msg = "Updating meta data failed.";
MS_LOG(ERROR) << update_meta_rsp_msg;
MS_LOG(WARNING) << update_meta_rsp_msg;
} else {
update_meta_rsp_msg = "Success";
}
if (!communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message)) {
MS_LOG(ERROR) << "Sending response failed.";
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
@@ -205,17 +205,17 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps
GetMetadataRequest get_metadata_req;
(void)get_metadata_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const std::string &name = get_metadata_req.name();
MS_LOG(INFO) << "Getting metadata for " << name;
MS_LOG(DEBUG) << "Getting metadata for " << name;

std::unique_lock<std::mutex> lock(mutex_[name]);
if (metadata_.count(name) == 0) {
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
MS_LOG(WARNING) << "The metadata of " << name << " is not registered.";
return;
}
PBMetadata stored_meta = metadata_[name];
std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) {
MS_LOG(ERROR) << "Sending response failed.";
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
@@ -224,7 +224,7 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps
bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
std::unique_lock<std::mutex> lock(mutex_[name]);
if (metadata_.count(name) == 0) {
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
MS_LOG(WARNING) << "The metadata of " << name << " is not registered.";
return false;
}
if (meta.has_device_meta()) {
@@ -265,13 +265,13 @@ bool DistributedMetadataStore::DoUpdateEncryptMetadata(const std::string &name,
if (meta.has_pair_client_keys()) {
bool keys_update_succeed = UpdatePairClientKeys(name, meta);
if (!keys_update_succeed) {
MS_LOG(ERROR) << "Update pair_client_keys failed.";
MS_LOG(WARNING) << "Update pair_client_keys failed.";
return false;
}
} else if (meta.has_pair_client_shares()) {
bool shares_update_succeed = UpdatePairClientShares(name, meta);
if (!shares_update_succeed) {
MS_LOG(ERROR) << "Update pair_client_shares failed.";
MS_LOG(WARNING) << "Update pair_client_shares failed.";
return false;
}
} else if (meta.has_one_client_noises()) {
@@ -305,8 +305,8 @@ bool DistributedMetadataStore::DoUpdateEncryptMetadata(const std::string &name,
auto &certificate = meta.pair_key_attestation().certificate();
key_attestation_map[fl_id] = certificate;
} else {
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value is not defined.";
MS_LOG(WARNING) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value is not defined.";
return false;
}
return true;
@@ -321,8 +321,8 @@ bool DistributedMetadataStore::UpdatePairClientKeys(const std::string &name, con
for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) {
if (fl_id == iter->first) {
add_flag = false;
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
MS_LOG(WARNING) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
break;
}
}
@@ -344,8 +344,8 @@ bool DistributedMetadataStore::UpdatePairClientShares(const std::string &name, c
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) {
if (fl_id == iter->first) {
add_flag = false;
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
MS_LOG(WARNING) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
break;
}
}


+ 1
- 1
mindspore/ccsrc/fl/server/iteration.cc View File

@@ -435,7 +435,7 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps
notify_leader_to_next_iter_rsp.set_result("success");
if (!communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(),
notify_leader_to_next_iter_rsp.SerializeAsString().size(), message)) {
MS_LOG(ERROR) << "Sending response failed.";
MS_LOG(WARNING) << "Sending response failed.";
return;
}



+ 6
- 0
mindspore/ccsrc/fl/server/iteration_timer.cc View File

@@ -31,6 +31,8 @@ IterationTimer::~IterationTimer() {
}

void IterationTimer::Start(const std::chrono::milliseconds &duration) {
std::unique_lock<std::mutex> lock(timer_mtx_);
MS_LOG(INFO) << "The timer begin to start.";
if (running_.load()) {
MS_LOG(WARNING) << "The timer already started.";
return;
@@ -47,13 +49,17 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
});
MS_LOG(INFO) << "The timer start success.";
}

void IterationTimer::Stop() {
std::unique_lock<std::mutex> lock(timer_mtx_);
MS_LOG(INFO) << "The timer begin to stop.";
running_ = false;
if (monitor_thread_.joinable()) {
monitor_thread_.join();
}
MS_LOG(INFO) << "The timer stop success.";
}

void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) {


+ 2
- 0
mindspore/ccsrc/fl/server/iteration_timer.h View File

@@ -57,6 +57,8 @@ class IterationTimer {
// The thread that keeps timing and call timeout_callback_ when the timer expires.
std::thread monitor_thread_;
TimeOutCb timeout_callback_;

std::mutex timer_mtx_;
};
} // namespace server
} // namespace fl


+ 3
- 3
mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h View File

@@ -135,9 +135,9 @@ class FedAvgKernel : public AggregationKernel {
ClearWeightAndDataSize();
}

MS_LOG(INFO) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for "
<< name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is "
<< data_size_addr[0];
MS_LOG(DEBUG) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for "
<< name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is "
<< data_size_addr[0];
for (size_t i = 0; i < inputs[2]->size / sizeof(T); i++) {
weight_addr[i] += new_weight_addr[i];
}


+ 6
- 6
mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc View File

@@ -61,7 +61,7 @@ bool GetModelKernel::Launch(const uint8_t *req_data, size_t len,

retry_count_ += 1;
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
MS_LOG(INFO) << "Launching GetModelKernel kernel. Retry count is " << retry_count_.load();
MS_LOG(DEBUG) << "Launching GetModelKernel kernel. Retry count is " << retry_count_.load();
}

const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
@@ -98,22 +98,22 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
std::to_string(next_req_time));
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
MS_LOG(WARNING) << reason;
MS_LOG(DEBUG) << reason;
}
return;
}

if (iter_to_model.count(get_model_iter) == 0) {
// If the model of get_model_iter is not stored, return the latest version of model and current iteration number.
MS_LOG(WARNING) << "The iteration of GetModel request " << std::to_string(get_model_iter)
<< " is invalid. Current iteration is " << std::to_string(current_iter);
MS_LOG(DEBUG) << "The iteration of GetModel request " << std::to_string(get_model_iter)
<< " is invalid. Current iteration is " << std::to_string(current_iter);
feature_maps = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
} else {
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
}
IncreaseAcceptClientNum();
MS_LOG(INFO) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter;
MS_LOG(DEBUG) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter;
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
current_iter, feature_maps, std::to_string(next_req_time));
return;


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.h View File

@@ -30,7 +30,7 @@ namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
constexpr uint32_t kPrintPullWeightForEveryRetryTime = 500;
constexpr uint32_t kPrintPullWeightForEveryRetryTime = 3000;
class PullWeightKernel : public RoundKernel {
public:
PullWeightKernel() : executor_(nullptr), retry_count_(0) {}


+ 12
- 12
mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc View File

@@ -53,7 +53,7 @@ void StartFLJobKernel::InitKernel(size_t) {

bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "Launching StartFLJobKernel kernel.";
MS_LOG(DEBUG) << "Launching StartFLJobKernel kernel.";
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
@@ -66,7 +66,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
if (!verifier.VerifyBuffer<schema::RequestFLJob>()) {
std::string reason = "The schema of RequestFLJob is invalid.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_RequestError, reason, false, "");
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@@ -83,7 +83,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
BuildStartFLJobRsp(
fbb, schema::ResponseCode_RequestError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
GenerateOutput(message, reason.c_str(), reason.size());
return true;
}
@@ -143,7 +143,7 @@ bool StartFLJobKernel::JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb,
BuildStartFLJobRsp(
fbb, schema::ResponseCode_RequestError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
return false;
}
unsigned char sign_data[sign_data_vector->size()];
@@ -172,9 +172,9 @@ bool StartFLJobKernel::JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb,
BuildStartFLJobRsp(
fbb, schema::ResponseCode_RequestError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
} else {
MS_LOG(INFO) << "JudgeFLJobVerify success." << ret;
MS_LOG(DEBUG) << "JudgeFLJobVerify success." << ret;
}

return ret;
@@ -201,7 +201,7 @@ bool StartFLJobKernel::StoreKeyAttestation(const std::shared_ptr<FBBuilder> &fbb
bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxClientKeyAttestation, pb_data);
if (!ret) {
std::string reason = "startFLJob: store key attestation failed";
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
@@ -232,7 +232,7 @@ ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<F
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
MS_LOG(DEBUG) << reason;
return ResultCode::kSuccessAndReturn;
}
return ResultCode::kSuccess;
@@ -246,7 +246,7 @@ DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *st
std::string fl_name = start_fl_job_req->fl_name()->str();
std::string fl_id = start_fl_job_req->fl_id()->str();
int data_size = start_fl_job_req->data_size();
MS_LOG(INFO) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size;
MS_LOG(DEBUG) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size;

DeviceMeta device_meta;
device_meta.set_fl_name(fl_name);
@@ -266,7 +266,7 @@ ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder>
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
MS_LOG(DEBUG) << reason;
}
return ret;
}
@@ -282,7 +282,7 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder>
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
}
return ResultCode::kSuccess;
@@ -305,7 +305,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
const std::string &next_req_time,
std::map<std::string, AddressPtr> feature_maps) {
if (fbb == nullptr) {
MS_LOG(ERROR) << "Input fbb is nullptr.";
MS_LOG(WARNING) << "Input fbb is nullptr.";
return;
}
auto fbs_reason = fbb->CreateString(reason);


+ 21
- 5
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc View File

@@ -45,7 +45,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {

bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
MS_LOG(DEBUG) << "Launching UpdateModelKernel kernel.";

std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
@@ -100,7 +100,8 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
MS_LOG(INFO) << "verify signature passed!";
}

result_code = UpdateModel(update_model_req, fbb);
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
result_code = VerifyUpdateModel(update_model_req, fbb, device_metas);
if (result_code != ResultCode::kSuccess) {
MS_LOG(ERROR) << "Updating model failed.";
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
@@ -112,6 +113,14 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}

result_code = UpdateModel(update_model_req, fbb, device_metas);
if (result_code != ResultCode::kSuccess) {
MS_LOG(ERROR) << "Updating model failed.";
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}

IncreaseAcceptClientNum();
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
@@ -155,8 +164,8 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr
return ResultCode::kSuccess;
}

ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb) {
ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
size_t iteration = IntToSize(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
@@ -169,7 +178,6 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
return ResultCode::kSuccessAndReturn;
}

PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
const auto &fl_id_to_meta = device_metas.device_metas().fl_id_to_meta();
std::string update_model_fl_id = update_model_req->fl_id()->str();
MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id;
@@ -198,7 +206,15 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
return ResultCode::kSuccessAndReturn;
}
}
return ResultCode::kSuccess;
}

ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
const auto &fl_id_to_meta = device_metas.device_metas().fl_id_to_meta();
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kSuccessAndReturn);
std::string update_model_fl_id = update_model_req->fl_id()->str();
size_t data_size = fl_id_to_meta.at(update_model_fl_id).data_size();
const auto &feature_map = ParseFeatureMap(update_model_req);
if (feature_map.empty()) {


+ 4
- 2
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h View File

@@ -52,14 +52,16 @@ class UpdateModelKernel : public RoundKernel {

private:
ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb);
ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb,
const PBMetadata &device_metas);
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req);
sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req);
void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time);

ResultCode VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas);
// The executor is for updating the model for updateModel request.
Executor *executor_{nullptr};



+ 2
- 2
mindspore/ccsrc/fl/server/memory_register.h View File

@@ -43,7 +43,7 @@ class MemoryRegister {
// avoid its data being released.
template <typename T>
void RegisterArray(const std::string &name, std::unique_ptr<T[]> *array, size_t size) {
MS_EXCEPTION_IF_NULL(array);
MS_ERROR_IF_NULL_WO_RET_VAL(array);
void *data = array->get();
AddressPtr addr = std::make_shared<Address>();
addr->addr = data;
@@ -62,7 +62,7 @@ class MemoryRegister {
auto char_arr = CastUniquePtr<char, T>(array);
StoreCharArray(&char_arr);
} else {
MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name();
MS_LOG(WARNING) << "MemoryRegister does not support type " << typeid(T).name();
return;
}



+ 3
- 3
mindspore/ccsrc/fl/server/model_store.cc View File

@@ -90,7 +90,7 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration
std::unique_lock<std::mutex> lock(model_mtx_);
std::map<std::string, AddressPtr> model = {};
if (iteration_to_model_.count(iteration) == 0) {
MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored.";
MS_LOG(WARNING) << "Model for iteration " << iteration << " is not stored. Return latest model";
return model;
}
model = iteration_to_model_[iteration]->addresses();
@@ -114,8 +114,8 @@ size_t ModelStore::model_size() const { return model_size_; }
std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel();
if (model.empty()) {
MS_LOG(EXCEPTION) << "Model feature map is empty.";
return nullptr;
MS_LOG(WARNING) << "Model feature map is empty.";
return std::make_shared<MemoryRegister>();
}

// Assign new memory for the model.


+ 9
- 3
mindspore/ccsrc/fl/server/round.cc View File

@@ -25,6 +25,8 @@ namespace fl {
namespace server {
class Server;
class Iteration;
std::atomic<uint32_t> kPrintTimes = 0;
const uint32_t kPrintTimesThreshold = 3000;
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
bool server_num_as_threshold)
: name_(name),
@@ -99,7 +101,7 @@ bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t
threshold_count_ = updated_threshold_count;
if (check_count_) {
if (!DistributedCountService::GetInstance().ReInitCounter(name_, threshold_count_)) {
MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed.";
MS_LOG(WARNING) << "Reinitializing count for " << name_ << " failed.";
return false;
}
}
@@ -124,7 +126,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
std::string reason = "";
if (!IsServerAvailable(&reason)) {
if (!message->SendResponse(reason.c_str(), reason.size())) {
MS_LOG(ERROR) << "Sending response failed.";
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
@@ -198,7 +200,11 @@ bool Round::IsServerAvailable(std::string *reason) {

// If the server is still in safemode, reject the request.
if (Server::GetInstance().IsSafeMode()) {
MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later.";
if (kPrintTimes % kPrintTimesThreshold == 0) {
MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later.";
kPrintTimes = 0;
}
kPrintTimes += 1;
*reason = ps::kClusterSafeMode;
return false;
}


+ 4
- 0
mindspore/ccsrc/fl/server/server_recovery.cc View File

@@ -22,6 +22,7 @@ namespace mindspore {
namespace fl {
namespace server {
bool ServerRecovery::Initialize(const std::string &config_file) {
std::unique_lock<std::mutex> lock(server_recovery_file_mtx_);
config_ = std::make_unique<ps::core::FileConfiguration>(config_file);
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
@@ -59,6 +60,7 @@ bool ServerRecovery::Initialize(const std::string &config_file) {
}

bool ServerRecovery::Recover() {
std::unique_lock<std::mutex> lock(server_recovery_file_mtx_);
server_recovery_file_.open(server_recovery_file_path_, std::ios::in);
if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) {
MS_LOG(WARNING) << "Can't open server recovery file " << server_recovery_file_path_;
@@ -80,6 +82,7 @@ bool ServerRecovery::Recover() {
}

bool ServerRecovery::Save(uint64_t current_iter) {
std::unique_lock<std::mutex> lock(server_recovery_file_mtx_);
server_recovery_file_.open(server_recovery_file_path_, std::ios::out | std::ios::ate);
if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) {
MS_LOG(WARNING) << "Can't save data to recovery file " << server_recovery_file_path_
@@ -96,6 +99,7 @@ bool ServerRecovery::Save(uint64_t current_iter) {

bool ServerRecovery::SyncAfterRecovery(const std::shared_ptr<ps::core::TcpCommunicator> &communicator,
uint32_t rank_id) {
std::unique_lock<std::mutex> lock(server_recovery_file_mtx_);
// If this server is follower server, notify leader server that this server has recovered.
if (rank_id != kLeaderServerRank) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator, false);


+ 2
- 0
mindspore/ccsrc/fl/server/server_recovery.h View File

@@ -22,6 +22,7 @@
#include <memory>
#include <string>
#include <vector>
#include <mutex>
#include "ps/core/recovery_base.h"
#include "ps/core/file_configuration.h"
#include "ps/core/communicator/tcp_communicator.h"
@@ -58,6 +59,7 @@ class ServerRecovery : public ps::core::RecoveryBase {

// The server recovery file object.
std::fstream server_recovery_file_;
std::mutex server_recovery_file_mtx_;
};
} // namespace server
} // namespace fl


+ 3
- 8
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -669,7 +669,6 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con

std::lock_guard<std::mutex> lock(client_mutex_);
connected_nodes_.clear();
PersistMetaData();
}

void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
@@ -696,7 +695,6 @@ void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &con
}
is_ready_ = true;
UpdateClusterState(ClusterState::CLUSTER_READY);
PersistMetaData();
}

void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
@@ -710,7 +708,6 @@ void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn
}
is_ready_ = true;
UpdateClusterState(ClusterState::CLUSTER_READY);
PersistMetaData();
}

void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
@@ -860,8 +857,7 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {

void AbstractNode::InitClientToServer() {
// create tcp client to myself in case of event dispatch failed when Send msg to server 0 failed
client_to_server_ =
std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, config_.get(), node_info_.node_role_);
client_to_server_ = std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, node_info_.node_role_);
MS_EXCEPTION_IF_NULL(client_to_server_);
client_to_server_->Init();
MS_LOG(INFO) << "The node start a tcp client to this node!";
@@ -872,8 +868,7 @@ bool AbstractNode::InitClientToScheduler() {
MS_LOG(WARNING) << "The config is empty.";
return false;
}
client_to_scheduler_ =
std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, config_.get(), NodeRole::SCHEDULER);
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, NodeRole::SCHEDULER);
MS_EXCEPTION_IF_NULL(client_to_scheduler_);
client_to_scheduler_->SetMessageCallback(
[&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) {
@@ -931,7 +926,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
MS_LOG(INFO) << "Create tcp client for role: " << role << ", rank: " << rank_id;
std::string ip = nodes_address_[key].first;
uint16_t port = nodes_address_[key].second;
auto client = std::make_shared<TcpClient>(ip, port, config_.get(), role);
auto client = std::make_shared<TcpClient>(ip, port, role);
MS_EXCEPTION_IF_NULL(client);
client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size) {


+ 2
- 3
mindspore/ccsrc/ps/core/communicator/tcp_client.cc View File

@@ -37,15 +37,14 @@ event_base *TcpClient::event_base_ = nullptr;
std::mutex TcpClient::event_base_mutex_;
bool TcpClient::is_started_ = false;

TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config, NodeRole peer_role)
TcpClient::TcpClient(const std::string &address, std::uint16_t port, NodeRole peer_role)
: event_timeout_(nullptr),
buffer_event_(nullptr),
server_address_(std::move(address)),
server_port_(port),
peer_role_(peer_role),
is_stop_(true),
is_connected_(false),
config_(config) {
is_connected_(false) {
message_handler_.SetCallback(
[this](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
if (message_callback_) {


+ 1
- 1
mindspore/ccsrc/ps/core/communicator/tcp_client.h View File

@@ -54,7 +54,7 @@ class TcpClient {
std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>;
using OnTimer = std::function<void()>;

explicit TcpClient(const std::string &address, std::uint16_t port, Configuration *const config, NodeRole peer_role);
explicit TcpClient(const std::string &address, std::uint16_t port, NodeRole peer_role);
virtual ~TcpClient();

std::string GetServerAddress() const;


+ 3
- 2
mindspore/ccsrc/ps/core/communicator/tcp_server.cc View File

@@ -239,6 +239,7 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConn

void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Remove connection fd: " << fd;
connections_.erase(fd);
}

@@ -383,7 +384,7 @@ void TcpServer::EventCallbackInner(struct bufferevent *bev, std::int16_t events,
MS_EXCEPTION_IF_NULL(srv);

if (events & BEV_EVENT_EOF) {
MS_LOG(INFO) << "Event buffer end of file, a client is disconnected from this server!";
MS_LOG(INFO) << "BEV_EVENT_EOF event is trigger!";
// Notify about disconnection
if (srv->client_disconnection_) {
srv->client_disconnection_(*srv, *conn);
@@ -391,7 +392,7 @@ void TcpServer::EventCallbackInner(struct bufferevent *bev, std::int16_t events,
// Free connection structures
srv->RemoveConnection(conn->GetFd());
} else if (events & BEV_EVENT_ERROR) {
MS_LOG(WARNING) << "Connect to server error.";
MS_LOG(WARNING) << "BEV_EVENT_ERROR event is trigger!";
if (PSContext::instance()->enable_ssl()) {
uint64_t err = bufferevent_get_openssl_error(bev);
MS_LOG(WARNING) << "The error number is:" << err;


+ 2
- 2
mindspore/ccsrc/ps/core/node.cc View File

@@ -80,8 +80,8 @@ bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::
if (!client->SendMessage(meta, protos, data, size)) {
MS_LOG(WARNING) << "Client send message failed.";
}
MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}



+ 1
- 0
mindspore/ccsrc/ps/core/node_recovery.cc View File

@@ -20,6 +20,7 @@ namespace mindspore {
namespace ps {
namespace core {
bool NodeRecovery::Recover() {
std::unique_lock<std::mutex> lock(recovery_mtx_);
if (recovery_storage_ == nullptr) {
return false;
}


+ 5
- 2
mindspore/ccsrc/ps/core/recovery_base.cc View File

@@ -20,6 +20,7 @@ namespace mindspore {
namespace ps {
namespace core {
bool RecoveryBase::Initialize(const std::string &config_json) {
std::unique_lock<std::mutex> lock(recovery_mtx_);
nlohmann::json recovery_config;
try {
recovery_config = nlohmann::json::parse(config_json);
@@ -87,7 +88,8 @@ bool RecoveryBase::InitializeNodes(const std::string &config_json) {
return true;
}

void RecoveryBase::Persist(const core::ClusterConfig &clusterConfig) const {
void RecoveryBase::Persist(const core::ClusterConfig &clusterConfig) {
std::unique_lock<std::mutex> lock(recovery_mtx_);
if (recovery_storage_ == nullptr) {
MS_LOG(WARNING) << "recovery storage is null, so don't persist meta data";
return;
@@ -95,7 +97,8 @@ void RecoveryBase::Persist(const core::ClusterConfig &clusterConfig) const {
recovery_storage_->PersistFile(clusterConfig);
}

void RecoveryBase::PersistNodesInfo(const core::ClusterConfig &clusterConfig) const {
void RecoveryBase::PersistNodesInfo(const core::ClusterConfig &clusterConfig) {
std::unique_lock<std::mutex> lock(recovery_mtx_);
if (scheduler_recovery_storage_ == nullptr) {
MS_LOG(WARNING) << "scheduler recovery storage is null, so don't persist nodes meta data";
return;


+ 5
- 2
mindspore/ccsrc/ps/core/recovery_base.h View File

@@ -22,6 +22,7 @@
#include <memory>
#include <string>
#include <vector>
#include <mutex>

#include "ps/constants.h"
#include "utils/log_adapter.h"
@@ -50,10 +51,10 @@ class RecoveryBase {
virtual bool Recover() = 0;

// Persist metadata to storage.
virtual void Persist(const core::ClusterConfig &clusterConfig) const;
virtual void Persist(const core::ClusterConfig &clusterConfig);

// Persist metadata to storage.
virtual void PersistNodesInfo(const core::ClusterConfig &clusterConfig) const;
virtual void PersistNodesInfo(const core::ClusterConfig &clusterConfig);

protected:
// Persistent storage used to save metadata.
@@ -64,6 +65,8 @@ class RecoveryBase {

// Storage type for recovery,Currently only supports storage of file types
StorageType storage_type_;

std::mutex recovery_mtx_;
};
} // namespace core
} // namespace ps


+ 4
- 4
mindspore/ccsrc/ps/core/scheduler_node.cc View File

@@ -67,8 +67,8 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
void SchedulerNode::RunRecovery() {
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
// create tcp client to myself in case of event dispatch failed when Send reconnect msg to server failed
client_to_scheduler_ = std::make_shared<TcpClient>(clusterConfig.scheduler_host, clusterConfig.scheduler_port,
config_.get(), NodeRole::SCHEDULER);
client_to_scheduler_ =
std::make_shared<TcpClient>(clusterConfig.scheduler_host, clusterConfig.scheduler_port, NodeRole::SCHEDULER);
MS_EXCEPTION_IF_NULL(client_to_scheduler_);
client_to_scheduler_->Init();
client_thread_ = std::make_unique<std::thread>([this]() {
@@ -95,7 +95,7 @@ void SchedulerNode::RunRecovery() {
for (const auto &kvs : initial_node_infos) {
auto &node_id = kvs.first;
auto &node_info = kvs.second;
auto client = std::make_shared<TcpClient>(node_info.ip_, node_info.port_, config_.get(), node_info.node_role_);
auto client = std::make_shared<TcpClient>(node_info.ip_, node_info.port_, node_info.node_role_);
client->SetMessageCallback([this](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *, size_t) {
MS_LOG(INFO) << "received the response. ";
NotifyMessageArrival(meta);
@@ -647,7 +647,7 @@ const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInf
std::string ip = node_info.ip_;
uint16_t port = node_info.port_;
MS_LOG(INFO) << "ip:" << ip << ", port:" << port << ", node id:" << node_info.node_id_;
auto client = std::make_shared<TcpClient>(ip, port, config_.get(), node_info.node_role_);
auto client = std::make_shared<TcpClient>(ip, port, node_info.node_role_);
MS_EXCEPTION_IF_NULL(client);
client->SetMessageCallback(
[&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {


+ 2
- 0
mindspore/ccsrc/ps/core/scheduler_recovery.cc View File

@@ -20,11 +20,13 @@ namespace mindspore {
namespace ps {
namespace core {
std::string SchedulerRecovery::GetMetadata(const std::string &key) {
std::unique_lock<std::mutex> lock(recovery_mtx_);
MS_EXCEPTION_IF_NULL(recovery_storage_);
return recovery_storage_->Get(key, "");
}

bool SchedulerRecovery::Recover() {
std::unique_lock<std::mutex> lock(recovery_mtx_);
if (recovery_storage_ == nullptr) {
return false;
}


+ 2
- 4
tests/ut/cpp/ps/core/tcp_client_tests.cc View File

@@ -28,8 +28,7 @@ class TestTcpClient : public UT::Common {
};

TEST_F(TestTcpClient, InitClientIPError) {
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000, config.get(), NodeRole::SERVER);
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000, NodeRole::SERVER);

client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;
@@ -42,8 +41,7 @@ TEST_F(TestTcpClient, InitClientIPError) {
}

TEST_F(TestTcpClient, InitClientPortErrorNoException) {
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
auto client = std::make_unique<TcpClient>("127.0.0.1", -1, config.get(), NodeRole::SERVER);
auto client = std::make_unique<TcpClient>("127.0.0.1", -1, NodeRole::SERVER);

client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;


+ 1
- 1
tests/ut/cpp/ps/core/tcp_pb_server_test.cc View File

@@ -60,7 +60,7 @@ class TestTcpServer : public UT::Common {

TEST_F(TestTcpServer, ServerSendMessage) {
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort(), config.get(), NodeRole::SERVER);
client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort(), NodeRole::SERVER);
std::cout << server_->BoundPort() << std::endl;
std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() {


Loading…
Cancel
Save