| @@ -72,11 +72,11 @@ bool CertVerify::verifyCertTime(const X509 *cert) const { | |||
| return false; | |||
| } | |||
| if (day < 0) { | |||
| MS_LOG(ERROR) << "cert start day time is later than now day time, day is" << day; | |||
| MS_LOG(WARNING) << "cert start day time is later than now day time, day is" << day; | |||
| return false; | |||
| } | |||
| if (day == 0 && sec < certStartTimeDiff) { | |||
| MS_LOG(ERROR) << "cert start second time is later than 600 second, second is" << sec; | |||
| MS_LOG(WARNING) << "cert start second time is later than 600 second, second is" << sec; | |||
| return false; | |||
| } | |||
| day = 0; | |||
| @@ -87,10 +87,10 @@ bool CertVerify::verifyCertTime(const X509 *cert) const { | |||
| } | |||
| if (day < 0 || sec < 0) { | |||
| MS_LOG(ERROR) << "cert end time is sooner than now time."; | |||
| MS_LOG(WARNING) << "cert end time is sooner than now time."; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "verify cert time success."; | |||
| MS_LOG(WARNING) << "verify cert time end."; | |||
| return true; | |||
| } | |||
| @@ -105,20 +105,20 @@ bool CertVerify::verifyPublicKey(const X509 *keyAttestationCertObj, const X509 * | |||
| int ret = 0; | |||
| ret = X509_verify(const_cast<X509 *>(keyAttestationCertObj), equipPubKey); | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "keyAttestationCert verify is failed"; | |||
| MS_LOG(WARNING) << "keyAttestationCert verify is failed"; | |||
| result = false; | |||
| break; | |||
| } | |||
| ret = X509_verify(const_cast<X509 *>(equipCertObj), equipCAPubKey); | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "equip cert verify is failed"; | |||
| MS_LOG(WARNING) << "equip cert verify is failed"; | |||
| result = false; | |||
| break; | |||
| } | |||
| int ret_first = X509_verify(const_cast<X509 *>(equipCACertObj), rootFirstPubKey); | |||
| int ret_second = X509_verify(const_cast<X509 *>(equipCACertObj), rootSecondPubKey); | |||
| if (ret_first != 1 && ret_second != 1) { | |||
| MS_LOG(ERROR) << "equip ca cert verify is failed"; | |||
| MS_LOG(WARNING) << "equip ca cert verify is failed"; | |||
| result = false; | |||
| break; | |||
| } | |||
| @@ -128,7 +128,7 @@ bool CertVerify::verifyPublicKey(const X509 *keyAttestationCertObj, const X509 * | |||
| EVP_PKEY_free(equipCAPubKey); | |||
| EVP_PKEY_free(rootFirstPubKey); | |||
| EVP_PKEY_free(rootSecondPubKey); | |||
| MS_LOG(INFO) << "verify Public Key success."; | |||
| MS_LOG(WARNING) << "verify Public Key end."; | |||
| return result; | |||
| } | |||
| @@ -143,7 +143,7 @@ bool CertVerify::verifyCAChain(const std::string &keyAttestation, const std::str | |||
| bool result = true; | |||
| do { | |||
| if (rootFirstCA == nullptr || rootSecondCA == nullptr) { | |||
| MS_LOG(ERROR) << "rootFirstCA or rootSecondCA is nullptr"; | |||
| MS_LOG(WARNING) << "rootFirstCA or rootSecondCA is nullptr"; | |||
| result = false; | |||
| break; | |||
| } | |||
| @@ -158,37 +158,37 @@ bool CertVerify::verifyCAChain(const std::string &keyAttestation, const std::str | |||
| } | |||
| if (!verifyCertCommonName(equipCACertObj, equipCertObj)) { | |||
| MS_LOG(ERROR) << "equip ca cert subject cn is not equal with equip cert issuer cn."; | |||
| MS_LOG(WARNING) << "equip ca cert subject cn is not equal with equip cert issuer cn."; | |||
| result = false; | |||
| break; | |||
| } | |||
| if (!verifyCertCommonName(rootFirstCA, equipCACertObj) && !verifyCertCommonName(rootSecondCA, equipCACertObj)) { | |||
| MS_LOG(ERROR) << "root CA cert subject cn is not equal with equip CA cert issuer cn."; | |||
| MS_LOG(WARNING) << "root CA cert subject cn is not equal with equip CA cert issuer cn."; | |||
| result = false; | |||
| break; | |||
| } | |||
| if (!verifyExtendedAttributes(equipCACertObj)) { | |||
| MS_LOG(ERROR) << "verify equipCACert Extended Attributes failed."; | |||
| MS_LOG(WARNING) << "verify equipCACert Extended Attributes failed."; | |||
| result = false; | |||
| break; | |||
| } | |||
| if (!verifyCertKeyID(rootFirstCA, equipCACertObj) && !verifyCertKeyID(rootSecondCA, equipCACertObj)) { | |||
| MS_LOG(ERROR) << "root CA cert subject keyid is not equal with equip CA cert issuer keyid."; | |||
| MS_LOG(WARNING) << "root CA cert subject keyid is not equal with equip CA cert issuer keyid."; | |||
| result = false; | |||
| break; | |||
| } | |||
| if (!verifyCertKeyID(equipCACertObj, equipCertObj)) { | |||
| MS_LOG(ERROR) << "equip CA cert subject keyid is not equal with equip cert issuer keyid."; | |||
| MS_LOG(WARNING) << "equip CA cert subject keyid is not equal with equip cert issuer keyid."; | |||
| result = false; | |||
| break; | |||
| } | |||
| if (!verifyPublicKey(keyAttestationCertObj, equipCertObj, equipCACertObj, rootFirstCA, rootSecondCA)) { | |||
| MS_LOG(ERROR) << "verify Public Key failed"; | |||
| MS_LOG(WARNING) << "verify Public Key failed"; | |||
| result = false; | |||
| break; | |||
| } | |||
| @@ -198,7 +198,7 @@ bool CertVerify::verifyCAChain(const std::string &keyAttestation, const std::str | |||
| X509_free(keyAttestationCertObj); | |||
| X509_free(equipCertObj); | |||
| X509_free(equipCACertObj); | |||
| MS_LOG(INFO) << "verifyCAChain success."; | |||
| MS_LOG(WARNING) << "verifyCAChain end."; | |||
| return result; | |||
| } | |||
| @@ -232,7 +232,7 @@ bool CertVerify::verifyCertKeyID(const X509 *caCert, const X509 *subCert) const | |||
| } | |||
| char issuer_keyid[512] = {0}; | |||
| if (akeyid->keyid == nullptr) { | |||
| MS_LOG(ERROR) << "keyid is nullprt."; | |||
| MS_LOG(WARNING) << "keyid is nullprt."; | |||
| result = false; | |||
| break; | |||
| } | |||
| @@ -271,11 +271,11 @@ bool CertVerify::verifyExtendedAttributes(const X509 *cert) const { | |||
| break; | |||
| } | |||
| if (!bcons->ca) { | |||
| MS_LOG(ERROR) << "Subject Type is End Entity."; | |||
| MS_LOG(WARNING) << "Subject Type is End Entity."; | |||
| result = false; | |||
| break; | |||
| } | |||
| MS_LOG(INFO) << "Subject Type is CA."; | |||
| MS_LOG(WARNING) << "Subject Type is CA."; | |||
| lASN1UsageStr = reinterpret_cast<ASN1_BIT_STRING *>(X509_get_ext_d2i(cert, NID_key_usage, NULL, NULL)); | |||
| if (lASN1UsageStr == nullptr) { | |||
| @@ -289,11 +289,11 @@ bool CertVerify::verifyExtendedAttributes(const X509 *cert) const { | |||
| } | |||
| if (!(usage & KU_KEY_CERT_SIGN)) { | |||
| MS_LOG(ERROR) << "Subject is not Certificate Signature."; | |||
| MS_LOG(WARNING) << "Subject is not Certificate Signature."; | |||
| result = false; | |||
| break; | |||
| } | |||
| MS_LOG(INFO) << "Subject is Certificate Signature."; | |||
| MS_LOG(WARNING) << "Subject is Certificate Signature."; | |||
| } while (0); | |||
| BASIC_CONSTRAINTS_free(bcons); | |||
| ASN1_BIT_STRING_free(lASN1UsageStr); | |||
| @@ -346,14 +346,14 @@ bool CertVerify::verifyCRL(const std::string &equipCert, const std::string &equi | |||
| } | |||
| if (equipCrl == nullptr) { | |||
| MS_LOG(INFO) << "equipCrl is nullptr. return true."; | |||
| MS_LOG(WARNING) << "equipCrl is nullptr. return true."; | |||
| result = true; | |||
| break; | |||
| } | |||
| evp_pkey = X509_get_pubkey(equipCertObj); | |||
| int ret = X509_CRL_verify(equipCrl, evp_pkey); | |||
| if (ret == 1) { | |||
| MS_LOG(ERROR) << "equip cert in equip crl, verify failed"; | |||
| MS_LOG(WARNING) << "equip cert in equip crl, verify failed"; | |||
| result = false; | |||
| break; | |||
| } | |||
| @@ -362,14 +362,14 @@ bool CertVerify::verifyCRL(const std::string &equipCert, const std::string &equi | |||
| EVP_PKEY_free(evp_pkey); | |||
| X509_free(equipCertObj); | |||
| X509_CRL_free(equipCrl); | |||
| MS_LOG(INFO) << "verifyCRL success."; | |||
| MS_LOG(WARNING) << "verifyCRL end."; | |||
| return result; | |||
| } | |||
| bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned char *signData, const std::string &flID, | |||
| const std::string &timeStamp) { | |||
| if (keyAttestation.empty() || signData == nullptr || flID.empty() || timeStamp.empty()) { | |||
| MS_LOG(ERROR) << "keyAttestation or signData or flID or timeStamp is empty."; | |||
| MS_LOG(WARNING) << "keyAttestation or signData or flID or timeStamp is empty."; | |||
| return false; | |||
| } | |||
| bool result = true; | |||
| @@ -386,7 +386,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned | |||
| pubKey = X509_get_pubkey(keyAttestationCertObj); | |||
| RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey); | |||
| if (pRSAPublicKey == nullptr) { | |||
| MS_LOG(ERROR) << "get rsa public key failed."; | |||
| MS_LOG(WARNING) << "get rsa public key failed."; | |||
| result = false; | |||
| break; | |||
| } | |||
| @@ -395,7 +395,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned | |||
| unsigned char buffer[256]; | |||
| int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING); | |||
| if (ret == -1) { | |||
| MS_LOG(ERROR) << "rsa public decrypt failed."; | |||
| MS_LOG(WARNING) << "rsa public decrypt failed."; | |||
| result = false; | |||
| break; | |||
| } | |||
| @@ -405,9 +405,9 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned | |||
| if (ret != 1) { | |||
| uint64_t ulErr = ERR_get_error(); | |||
| char szErrMsg[1024] = {0}; | |||
| MS_LOG(ERROR) << "verify error. error number: " << ulErr; | |||
| MS_LOG(WARNING) << "verify WARNING. WARNING number: " << ulErr; | |||
| std::string str_res = ERR_error_string(ulErr, szErrMsg); | |||
| MS_LOG(ERROR) << szErrMsg; | |||
| MS_LOG(WARNING) << szErrMsg; | |||
| if (str_res.empty()) { | |||
| result = false; | |||
| break; | |||
| @@ -420,7 +420,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned | |||
| X509_free(keyAttestationCertObj); | |||
| CRYPTO_cleanup_all_ex_data(); | |||
| MS_LOG(INFO) << "verifyRSAKey success."; | |||
| MS_LOG(WARNING) << "verifyRSAKey end."; | |||
| return result; | |||
| } | |||
| @@ -445,7 +445,7 @@ void CertVerify::sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash | |||
| std::string CertVerify::toHexString(const unsigned char *data, const int len) { | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "data hash is null."; | |||
| MS_LOG(WARNING) << "data hash is null."; | |||
| return ""; | |||
| } | |||
| @@ -465,10 +465,10 @@ bool CertVerify::verifyEquipCertAndFlID(const std::string &flID, const std::stri | |||
| sha256Hash(equipCert, hash, SHA256_DIGEST_LENGTH); | |||
| std::string equipCertSha256 = toHexString(hash, SHA256_DIGEST_LENGTH); | |||
| if (flID == equipCertSha256) { | |||
| MS_LOG(INFO) << "verifyEquipCertAndFlID success."; | |||
| MS_LOG(WARNING) << "verifyEquipCertAndFlID success."; | |||
| return true; | |||
| } else { | |||
| MS_LOG(ERROR) << "verifyEquipCertAndFlID failed."; | |||
| MS_LOG(WARNING) << "verifyEquipCertAndFlID failed."; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -482,13 +482,13 @@ bool CertVerify::verifyTimeStamp(const std::string &flID, const std::string &tim | |||
| return false; | |||
| } | |||
| int64_t now = tv.tv_sec * base + tv.tv_usec / base; | |||
| MS_LOG(INFO) << "flID: " << flID.c_str() << ",now time: " << now << ",requestTime: " << requestTime; | |||
| MS_LOG(WARNING) << "flID: " << flID.c_str() << ",now time: " << now << ",requestTime: " << requestTime; | |||
| int64_t diff = now - requestTime; | |||
| if (abs(diff) > replayAttackTimeDiff) { | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "verifyTimeStamp success."; | |||
| MS_LOG(WARNING) << "verifyTimeStamp success."; | |||
| return true; | |||
| } | |||
| @@ -514,7 +514,7 @@ void CertVerify::sha256Hash(const std::string &src, uint8_t *hash, const int len | |||
| bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData, | |||
| int srcDataLen) { | |||
| if (keyAttestation.empty() || signData == nullptr || srcData == nullptr || srcDataLen <= 0) { | |||
| MS_LOG(ERROR) << "keyAttestation or signData or srcData is invalid."; | |||
| MS_LOG(WARNING) << "keyAttestation or signData or srcData is invalid."; | |||
| return false; | |||
| } | |||
| bool result = true; | |||
| @@ -525,7 +525,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t * | |||
| pubKey = X509_get_pubkey(keyAttestationCertObj); | |||
| RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey); | |||
| if (pRSAPublicKey == nullptr) { | |||
| MS_LOG(ERROR) << "get rsa public key failed."; | |||
| MS_LOG(WARNING) << "get rsa public key failed."; | |||
| result = false; | |||
| break; | |||
| } | |||
| @@ -534,7 +534,7 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t * | |||
| unsigned char buffer[256]; | |||
| int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING); | |||
| if (ret == -1) { | |||
| MS_LOG(ERROR) << "rsa public decrypt failed."; | |||
| MS_LOG(WARNING) << "rsa public decrypt failed."; | |||
| result = false; | |||
| break; | |||
| } | |||
| @@ -544,9 +544,9 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t * | |||
| if (ret != 1) { | |||
| uint64_t ulErr = ERR_get_error(); | |||
| char szErrMsg[1024] = {0}; | |||
| MS_LOG(ERROR) << "verify error. error number: " << ulErr; | |||
| MS_LOG(WARNING) << "verify WARNING. WARNING number: " << ulErr; | |||
| std::string str_res = ERR_error_string(ulErr, szErrMsg); | |||
| MS_LOG(ERROR) << szErrMsg; | |||
| MS_LOG(WARNING) << szErrMsg; | |||
| if (str_res.empty()) { | |||
| result = false; | |||
| break; | |||
| @@ -559,23 +559,23 @@ bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t * | |||
| X509_free(keyAttestationCertObj); | |||
| CRYPTO_cleanup_all_ex_data(); | |||
| MS_LOG(INFO) << "verifyRSAKey success."; | |||
| MS_LOG(WARNING) << "verifyRSAKey end."; | |||
| return result; | |||
| } | |||
| bool CertVerify::initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath, | |||
| const std::string equipCrlPath, const uint64_t replay_attack_time_diff) { | |||
| if (rootFirstCaFilePath.empty() || rootSecondCaFilePath.empty()) { | |||
| MS_LOG(ERROR) << "the root or crl path is empty."; | |||
| MS_LOG(WARNING) << "the root or crl path is empty."; | |||
| return false; | |||
| } | |||
| if (!checkFileExists(rootFirstCaFilePath)) { | |||
| MS_LOG(ERROR) << "The rootFirstCaFilePath is not exist."; | |||
| MS_LOG(WARNING) << "The rootFirstCaFilePath is not exist."; | |||
| return false; | |||
| } | |||
| if (!checkFileExists(rootSecondCaFilePath)) { | |||
| MS_LOG(ERROR) << "The rootSecondCaFilePath is not exist."; | |||
| MS_LOG(WARNING) << "The rootSecondCaFilePath is not exist."; | |||
| return false; | |||
| } | |||
| @@ -63,10 +63,10 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size | |||
| T *output_buff = reinterpret_cast<T *>(recvbuff); | |||
| uint32_t send_to_rank = (rank_id_ + 1) % rank_size; | |||
| uint32_t recv_from_rank = (rank_id_ - 1 + rank_size) % rank_size; | |||
| MS_LOG(INFO) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", rank_id_:" << rank_id_ | |||
| << ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size | |||
| << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank | |||
| << ", recv_from_rank:" << recv_from_rank; | |||
| MS_LOG(DEBUG) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", rank_id_:" << rank_id_ | |||
| << ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size | |||
| << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank | |||
| << ", recv_from_rank:" << recv_from_rank; | |||
| // Ring ReduceScatter. | |||
| MS_LOG(DEBUG) << "Start Ring ReduceScatter."; | |||
| @@ -148,8 +148,8 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec | |||
| MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false); | |||
| uint32_t rank_size = server_num_; | |||
| MS_LOG(INFO) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", rank_id_:" << rank_id_ | |||
| << ", count:" << count; | |||
| MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", rank_id_:" << rank_id_ | |||
| << ", count:" << count; | |||
| size_t src_size = count * sizeof(T); | |||
| size_t dst_size = count * sizeof(T); | |||
| @@ -192,7 +192,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec | |||
| MS_LOG(DEBUG) << "End Reduce."; | |||
| // Broadcast data from rank 0 to every process whose rank is not 0. | |||
| MS_LOG(INFO) << "Start broadcast from rank 0 to other processes."; | |||
| MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes."; | |||
| if (rank_id_ == 0) { | |||
| for (uint32_t i = 1; i < rank_size; i++) { | |||
| MS_LOG(DEBUG) << "Broadcast data to process " << i; | |||
| @@ -240,9 +240,9 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff | |||
| uint32_t send_to_rank = (rank_id_ + 1) % rank_size_; | |||
| uint32_t recv_from_rank = (rank_id_ - 1 + rank_size_) % rank_size_; | |||
| MS_LOG(INFO) << "Ring AllGather count:" << send_count << ", rank_size:" << rank_size_ << ", rank_id_:" << rank_id_ | |||
| << ", chunk_size:" << chunk_size << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank | |||
| << ", recv_from_rank:" << recv_from_rank; | |||
| MS_LOG(DEBUG) << "Ring AllGather count:" << send_count << ", rank_size:" << rank_size_ << ", rank_id_:" << rank_id_ | |||
| << ", chunk_size:" << chunk_size << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank | |||
| << ", recv_from_rank:" << recv_from_rank; | |||
| T *output_buff = reinterpret_cast<T *>(recvbuff); | |||
| size_t src_size = send_count * sizeof(T); | |||
| @@ -301,7 +301,7 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c | |||
| uint32_t global_root_rank = group_to_global_ranks[root]; | |||
| // Broadcast data to processes which are not the root. | |||
| MS_LOG(INFO) << "Start broadcast from root to other processes."; | |||
| MS_LOG(DEBUG) << "Start broadcast from root to other processes."; | |||
| if (rank_id_ == global_root_rank) { | |||
| for (uint32_t i = 1; i < group_rank_size; i++) { | |||
| uint32_t dst_rank = group_to_global_ranks[i]; | |||
| @@ -81,16 +81,16 @@ bool DistributedCountService::ReInitCounter(const std::string &name, size_t glob | |||
| } | |||
| bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) { | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; | |||
| MS_LOG(DEBUG) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; | |||
| if (local_rank_ == counting_server_rank_) { | |||
| if (global_threshold_count_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "Counter for " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "Counter for " << name << " is not registered."; | |||
| return false; | |||
| } | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| if (global_current_count_[name].size() >= global_threshold_count_[name]) { | |||
| MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is " | |||
| MS_LOG(DEBUG) << "Count for " << name << " is already enough. Threshold count is " | |||
| << global_threshold_count_[name]; | |||
| return false; | |||
| } | |||
| @@ -98,7 +98,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id; | |||
| (void)global_current_count_[name].insert(id); | |||
| if (!TriggerCounterEvent(name, reason)) { | |||
| MS_LOG(ERROR) << "Leader server trigger count event failed."; | |||
| MS_LOG(WARNING) << "Leader server trigger count event failed."; | |||
| return false; | |||
| } | |||
| } else { | |||
| @@ -110,7 +110,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr; | |||
| if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount, | |||
| &report_cnt_rsp_msg)) { | |||
| MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name; | |||
| MS_LOG(WARNING) << "Sending reporting count message to leader server failed for " << name; | |||
| if (reason != nullptr) { | |||
| *reason = kNetworkError; | |||
| } | |||
| @@ -121,7 +121,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| CountResponse count_rsp; | |||
| (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); | |||
| if (!count_rsp.result()) { | |||
| MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason(); | |||
| MS_LOG(WARNING) << "Reporting count failed:" << count_rsp.reason(); | |||
| // If the error is caused by the network issue, return the reason. | |||
| if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) { | |||
| *reason = kNetworkError; | |||
| @@ -133,10 +133,10 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| } | |||
| bool DistributedCountService::CountReachThreshold(const std::string &name) { | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name; | |||
| MS_LOG(DEBUG) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name; | |||
| if (local_rank_ == counting_server_rank_) { | |||
| if (global_threshold_count_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "Counter for " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "Counter for " << name << " is not registered."; | |||
| return false; | |||
| } | |||
| @@ -149,7 +149,8 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) { | |||
| std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr; | |||
| if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_, | |||
| ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) { | |||
| MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name; | |||
| MS_LOG(WARNING) << "Sending querying whether count reaches threshold message to leader server failed for " | |||
| << name; | |||
| return false; | |||
| } | |||
| @@ -202,10 +203,10 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core: | |||
| std::string reason = "Counter for " + name + " is not registered."; | |||
| count_rsp.set_result(false); | |||
| count_rsp.set_reason(reason); | |||
| MS_LOG(ERROR) << reason; | |||
| MS_LOG(WARNING) << reason; | |||
| if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), | |||
| message)) { | |||
| MS_LOG(ERROR) << "Sending response failed."; | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| return; | |||
| @@ -217,10 +218,10 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core: | |||
| "Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]); | |||
| count_rsp.set_result(false); | |||
| count_rsp.set_reason(reason); | |||
| MS_LOG(ERROR) << reason; | |||
| MS_LOG(WARNING) << reason; | |||
| if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), | |||
| message)) { | |||
| MS_LOG(ERROR) << "Sending response failed."; | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| return; | |||
| @@ -239,7 +240,7 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core: | |||
| } | |||
| if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), | |||
| message)) { | |||
| MS_LOG(ERROR) << "Sending response failed."; | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| return; | |||
| @@ -254,7 +255,7 @@ void DistributedCountService::HandleCountReachThresholdRequest( | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| if (global_threshold_count_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "Counter for " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "Counter for " << name << " is not registered."; | |||
| return; | |||
| } | |||
| @@ -262,7 +263,7 @@ void DistributedCountService::HandleCountReachThresholdRequest( | |||
| count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]); | |||
| if (!communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(), | |||
| count_reach_threshold_rsp.SerializeAsString().size(), message)) { | |||
| MS_LOG(ERROR) << "Sending response failed."; | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| return; | |||
| @@ -274,7 +275,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core: | |||
| // callbacks. | |||
| std::string couter_event_rsp_msg = "success"; | |||
| if (!communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message)) { | |||
| MS_LOG(ERROR) << "Sending response failed."; | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| @@ -284,7 +285,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core: | |||
| const auto &name = counter_event.name(); | |||
| if (counter_handlers_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "The counter handler of " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "The counter handler of " << name << " is not registered."; | |||
| return; | |||
| } | |||
| MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name; | |||
| @@ -293,7 +294,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core: | |||
| } else if (type == CounterEventType::LAST_CNT) { | |||
| counter_handlers_[name].last_count_handler(message); | |||
| } else { | |||
| MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid."; | |||
| MS_LOG(WARNING) << "DistributedCountService event type " << type << " is invalid."; | |||
| return; | |||
| } | |||
| return; | |||
| @@ -301,7 +302,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core: | |||
| bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) { | |||
| if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "The counter of " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "The counter of " << name << " is not registered."; | |||
| return false; | |||
| } | |||
| @@ -331,7 +332,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st | |||
| for (uint32_t i = 1; i < server_num_; i++) { | |||
| MS_LOG(INFO) << "Start sending first count event message to server " << i; | |||
| if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { | |||
| MS_LOG(ERROR) << "Activating first count event to server " << i << " failed."; | |||
| MS_LOG(WARNING) << "Activating first count event to server " << i << " failed."; | |||
| if (reason != nullptr) { | |||
| *reason = kNetworkError; | |||
| } | |||
| @@ -340,7 +341,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st | |||
| } | |||
| if (counter_handlers_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "The counter handler of " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "The counter handler of " << name << " is not registered."; | |||
| return false; | |||
| } | |||
| // Leader server directly calls the callback. | |||
| @@ -360,7 +361,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std | |||
| for (uint32_t i = 1; i < server_num_; i++) { | |||
| MS_LOG(INFO) << "Start sending last count event message to server " << i; | |||
| if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { | |||
| MS_LOG(ERROR) << "Activating last count event to server " << i << " failed."; | |||
| MS_LOG(WARNING) << "Activating last count event to server " << i << " failed."; | |||
| if (reason != nullptr) { | |||
| *reason = kNetworkError; | |||
| } | |||
| @@ -369,7 +370,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std | |||
| } | |||
| if (counter_handlers_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "The counter handler of " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "The counter handler of " << name << " is not registered."; | |||
| return false; | |||
| } | |||
| // Leader server directly calls the callback. | |||
| @@ -43,7 +43,7 @@ void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps: | |||
| void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) { | |||
| if (router_ == nullptr) { | |||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||
| MS_LOG(WARNING) << "The consistent hash ring is not initialized yet."; | |||
| return; | |||
| } | |||
| @@ -63,14 +63,14 @@ void DistributedMetadataStore::RegisterMetadata(const std::string &name, const P | |||
| void DistributedMetadataStore::ResetMetadata(const std::string &name) { | |||
| if (router_ == nullptr) { | |||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||
| MS_LOG(WARNING) << "The consistent hash ring is not initialized yet."; | |||
| return; | |||
| } | |||
| uint32_t stored_rank = router_->Find(name); | |||
| if (local_rank_ == stored_rank) { | |||
| if (metadata_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "The metadata for " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "The metadata for " << name << " is not registered."; | |||
| return; | |||
| } | |||
| @@ -84,15 +84,15 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) { | |||
| bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) { | |||
| if (router_ == nullptr) { | |||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||
| MS_LOG(WARNING) << "The consistent hash ring is not initialized yet."; | |||
| return false; | |||
| } | |||
| uint32_t stored_rank = router_->Find(name); | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank; | |||
| MS_LOG(DEBUG) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank; | |||
| if (local_rank_ == stored_rank) { | |||
| if (!DoUpdateMetadata(name, meta)) { | |||
| MS_LOG(ERROR) << "Updating meta data failed."; | |||
| MS_LOG(WARNING) << "Updating meta data failed."; | |||
| return false; | |||
| } | |||
| } else { | |||
| @@ -102,7 +102,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM | |||
| std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr; | |||
| if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata, | |||
| &update_meta_rsp_msg)) { | |||
| MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed."; | |||
| MS_LOG(WARNING) << "Sending updating metadata message to server " << stored_rank << " failed."; | |||
| if (reason != nullptr) { | |||
| *reason = kNetworkError; | |||
| } | |||
| @@ -113,7 +113,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM | |||
| std::string update_meta_rsp = | |||
| std::string(reinterpret_cast<char *>(update_meta_rsp_msg->data()), update_meta_rsp_msg->size()); | |||
| if (update_meta_rsp != kSuccess) { | |||
| MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed. " << update_meta_rsp; | |||
| MS_LOG(WARNING) << "Updating metadata in server " << stored_rank << " failed. " << update_meta_rsp; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -122,11 +122,11 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM | |||
| PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { | |||
| if (router_ == nullptr) { | |||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||
| MS_LOG(WARNING) << "The consistent hash ring is not initialized yet."; | |||
| return {}; | |||
| } | |||
| uint32_t stored_rank = router_->Find(name); | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank; | |||
| MS_LOG(DEBUG) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank; | |||
| if (local_rank_ == stored_rank) { | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| return metadata_[name]; | |||
| @@ -138,7 +138,7 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { | |||
| std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr; | |||
| if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetMetadata, | |||
| &get_meta_rsp_msg)) { | |||
| MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed."; | |||
| MS_LOG(WARNING) << "Sending getting metadata message to server " << stored_rank << " failed."; | |||
| return get_metadata_rsp; | |||
| } | |||
| @@ -184,17 +184,17 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr | |||
| PBMetadataWithName meta_with_name; | |||
| (void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const std::string &name = meta_with_name.name(); | |||
| MS_LOG(INFO) << "Update metadata for " << name; | |||
| MS_LOG(DEBUG) << "Update metadata for " << name; | |||
| std::string update_meta_rsp_msg; | |||
| if (!DoUpdateMetadata(name, meta_with_name.metadata())) { | |||
| update_meta_rsp_msg = "Updating meta data failed."; | |||
| MS_LOG(ERROR) << update_meta_rsp_msg; | |||
| MS_LOG(WARNING) << update_meta_rsp_msg; | |||
| } else { | |||
| update_meta_rsp_msg = "Success"; | |||
| } | |||
| if (!communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message)) { | |||
| MS_LOG(ERROR) << "Sending response failed."; | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| return; | |||
| @@ -205,17 +205,17 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps | |||
| GetMetadataRequest get_metadata_req; | |||
| (void)get_metadata_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const std::string &name = get_metadata_req.name(); | |||
| MS_LOG(INFO) << "Getting metadata for " << name; | |||
| MS_LOG(DEBUG) << "Getting metadata for " << name; | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| if (metadata_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "The metadata of " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "The metadata of " << name << " is not registered."; | |||
| return; | |||
| } | |||
| PBMetadata stored_meta = metadata_[name]; | |||
| std::string getting_meta_rsp_msg = stored_meta.SerializeAsString(); | |||
| if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) { | |||
| MS_LOG(ERROR) << "Sending response failed."; | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| return; | |||
| @@ -224,7 +224,7 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps | |||
| bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) { | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| if (metadata_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "The metadata of " << name << " is not registered."; | |||
| MS_LOG(WARNING) << "The metadata of " << name << " is not registered."; | |||
| return false; | |||
| } | |||
| if (meta.has_device_meta()) { | |||
| @@ -265,13 +265,13 @@ bool DistributedMetadataStore::DoUpdateEncryptMetadata(const std::string &name, | |||
| if (meta.has_pair_client_keys()) { | |||
| bool keys_update_succeed = UpdatePairClientKeys(name, meta); | |||
| if (!keys_update_succeed) { | |||
| MS_LOG(ERROR) << "Update pair_client_keys failed."; | |||
| MS_LOG(WARNING) << "Update pair_client_keys failed."; | |||
| return false; | |||
| } | |||
| } else if (meta.has_pair_client_shares()) { | |||
| bool shares_update_succeed = UpdatePairClientShares(name, meta); | |||
| if (!shares_update_succeed) { | |||
| MS_LOG(ERROR) << "Update pair_client_shares failed."; | |||
| MS_LOG(WARNING) << "Update pair_client_shares failed."; | |||
| return false; | |||
| } | |||
| } else if (meta.has_one_client_noises()) { | |||
| @@ -305,8 +305,8 @@ bool DistributedMetadataStore::DoUpdateEncryptMetadata(const std::string &name, | |||
| auto &certificate = meta.pair_key_attestation().certificate(); | |||
| key_attestation_map[fl_id] = certificate; | |||
| } else { | |||
| MS_LOG(ERROR) << "Leader server updating value for " << name | |||
| << " failed: The Protobuffer of this value is not defined."; | |||
| MS_LOG(WARNING) << "Leader server updating value for " << name | |||
| << " failed: The Protobuffer of this value is not defined."; | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -321,8 +321,8 @@ bool DistributedMetadataStore::UpdatePairClientKeys(const std::string &name, con | |||
| for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) { | |||
| if (fl_id == iter->first) { | |||
| add_flag = false; | |||
| MS_LOG(ERROR) << "Leader server updating value for " << name | |||
| << " failed: The Protobuffer of this value already exists."; | |||
| MS_LOG(WARNING) << "Leader server updating value for " << name | |||
| << " failed: The Protobuffer of this value already exists."; | |||
| break; | |||
| } | |||
| } | |||
| @@ -344,8 +344,8 @@ bool DistributedMetadataStore::UpdatePairClientShares(const std::string &name, c | |||
| for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) { | |||
| if (fl_id == iter->first) { | |||
| add_flag = false; | |||
| MS_LOG(ERROR) << "Leader server updating value for " << name | |||
| << " failed: The Protobuffer of this value already exists."; | |||
| MS_LOG(WARNING) << "Leader server updating value for " << name | |||
| << " failed: The Protobuffer of this value already exists."; | |||
| break; | |||
| } | |||
| } | |||
| @@ -435,7 +435,7 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps | |||
| notify_leader_to_next_iter_rsp.set_result("success"); | |||
| if (!communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(), | |||
| notify_leader_to_next_iter_rsp.SerializeAsString().size(), message)) { | |||
| MS_LOG(ERROR) << "Sending response failed."; | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| @@ -31,6 +31,8 @@ IterationTimer::~IterationTimer() { | |||
| } | |||
| void IterationTimer::Start(const std::chrono::milliseconds &duration) { | |||
| std::unique_lock<std::mutex> lock(timer_mtx_); | |||
| MS_LOG(INFO) << "The timer begin to start."; | |||
| if (running_.load()) { | |||
| MS_LOG(WARNING) << "The timer already started."; | |||
| return; | |||
| @@ -47,13 +49,17 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(1)); | |||
| } | |||
| }); | |||
| MS_LOG(INFO) << "The timer start success."; | |||
| } | |||
// Stops the timer: clears running_ and, when monitor_thread_ is joinable, joins it
// before returning. timer_mtx_ is held for the whole call (Start() locks the same
// mutex), and an INFO line is logged on entry and again once the stop has completed.
// NOTE(review): the join happens while timer_mtx_ is held — presumably the monitor
// thread never calls Start()/Stop() itself; confirm against the timeout callback.
| void IterationTimer::Stop() { | |||
| std::unique_lock<std::mutex> lock(timer_mtx_); | |||
| MS_LOG(INFO) << "The timer begin to stop."; | |||
| running_ = false; | |||
| if (monitor_thread_.joinable()) { | |||
| monitor_thread_.join(); | |||
| } | |||
| MS_LOG(INFO) << "The timer stop success."; | |||
| } | |||
| void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) { | |||
| @@ -57,6 +57,8 @@ class IterationTimer { | |||
| // The thread that keeps timing and call timeout_callback_ when the timer expires. | |||
| std::thread monitor_thread_; | |||
| TimeOutCb timeout_callback_; | |||
| std::mutex timer_mtx_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace fl | |||
| @@ -135,9 +135,9 @@ class FedAvgKernel : public AggregationKernel { | |||
| ClearWeightAndDataSize(); | |||
| } | |||
| MS_LOG(INFO) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for " | |||
| << name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is " | |||
| << data_size_addr[0]; | |||
| MS_LOG(DEBUG) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for " | |||
| << name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is " | |||
| << data_size_addr[0]; | |||
| for (size_t i = 0; i < inputs[2]->size / sizeof(T); i++) { | |||
| weight_addr[i] += new_weight_addr[i]; | |||
| } | |||
| @@ -61,7 +61,7 @@ bool GetModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| retry_count_ += 1; | |||
| if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) { | |||
| MS_LOG(INFO) << "Launching GetModelKernel kernel. Retry count is " << retry_count_.load(); | |||
| MS_LOG(DEBUG) << "Launching GetModelKernel kernel. Retry count is " << retry_count_.load(); | |||
| } | |||
| const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data); | |||
| @@ -98,22 +98,22 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons | |||
| BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps, | |||
| std::to_string(next_req_time)); | |||
| if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) { | |||
| MS_LOG(WARNING) << reason; | |||
| MS_LOG(DEBUG) << reason; | |||
| } | |||
| return; | |||
| } | |||
| if (iter_to_model.count(get_model_iter) == 0) { | |||
| // If the model of get_model_iter is not stored, return the latest version of model and current iteration number. | |||
| MS_LOG(WARNING) << "The iteration of GetModel request " << std::to_string(get_model_iter) | |||
| << " is invalid. Current iteration is " << std::to_string(current_iter); | |||
| MS_LOG(DEBUG) << "The iteration of GetModel request " << std::to_string(get_model_iter) | |||
| << " is invalid. Current iteration is " << std::to_string(current_iter); | |||
| feature_maps = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num); | |||
| } else { | |||
| feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter); | |||
| } | |||
| IncreaseAcceptClientNum(); | |||
| MS_LOG(INFO) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() | |||
| << ", next request time is " << next_req_time << ", current iteration is " << current_iter; | |||
| MS_LOG(DEBUG) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() | |||
| << ", next request time is " << next_req_time << ", current iteration is " << current_iter; | |||
| BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter), | |||
| current_iter, feature_maps, std::to_string(next_req_time)); | |||
| return; | |||
| @@ -30,7 +30,7 @@ namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| namespace kernel { | |||
| constexpr uint32_t kPrintPullWeightForEveryRetryTime = 500; | |||
| constexpr uint32_t kPrintPullWeightForEveryRetryTime = 3000; | |||
| class PullWeightKernel : public RoundKernel { | |||
| public: | |||
| PullWeightKernel() : executor_(nullptr), retry_count_(0) {} | |||
| @@ -53,7 +53,7 @@ void StartFLJobKernel::InitKernel(size_t) { | |||
| bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| const std::shared_ptr<ps::core::MessageHandler> &message) { | |||
| MS_LOG(INFO) << "Launching StartFLJobKernel kernel."; | |||
| MS_LOG(DEBUG) << "Launching StartFLJobKernel kernel."; | |||
| std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>(); | |||
| if (fbb == nullptr || req_data == nullptr) { | |||
| std::string reason = "FBBuilder builder or req_data is nullptr."; | |||
| @@ -66,7 +66,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| if (!verifier.VerifyBuffer<schema::RequestFLJob>()) { | |||
| std::string reason = "The schema of RequestFLJob is invalid."; | |||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_RequestError, reason, false, ""); | |||
| MS_LOG(ERROR) << reason; | |||
| MS_LOG(WARNING) << reason; | |||
| GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| @@ -83,7 +83,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_RequestError, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(ERROR) << reason; | |||
| MS_LOG(WARNING) << reason; | |||
| GenerateOutput(message, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| @@ -143,7 +143,7 @@ bool StartFLJobKernel::JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb, | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_RequestError, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(ERROR) << reason; | |||
| MS_LOG(WARNING) << reason; | |||
| return false; | |||
| } | |||
| unsigned char sign_data[sign_data_vector->size()]; | |||
| @@ -172,9 +172,9 @@ bool StartFLJobKernel::JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb, | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_RequestError, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(ERROR) << reason; | |||
| MS_LOG(WARNING) << reason; | |||
| } else { | |||
| MS_LOG(INFO) << "JudgeFLJobVerify success." << ret; | |||
| MS_LOG(DEBUG) << "JudgeFLJobVerify success." << ret; | |||
| } | |||
| return ret; | |||
| @@ -201,7 +201,7 @@ bool StartFLJobKernel::StoreKeyAttestation(const std::shared_ptr<FBBuilder> &fbb | |||
| bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxClientKeyAttestation, pb_data); | |||
| if (!ret) { | |||
| std::string reason = "startFLJob: store key attestation failed"; | |||
| MS_LOG(ERROR) << reason; | |||
| MS_LOG(WARNING) << reason; | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| @@ -232,7 +232,7 @@ ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<F | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(WARNING) << reason; | |||
| MS_LOG(DEBUG) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| } | |||
| return ResultCode::kSuccess; | |||
| @@ -246,7 +246,7 @@ DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *st | |||
| std::string fl_name = start_fl_job_req->fl_name()->str(); | |||
| std::string fl_id = start_fl_job_req->fl_id()->str(); | |||
| int data_size = start_fl_job_req->data_size(); | |||
| MS_LOG(INFO) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size; | |||
| MS_LOG(DEBUG) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size; | |||
| DeviceMeta device_meta; | |||
| device_meta.set_fl_name(fl_name); | |||
| @@ -266,7 +266,7 @@ ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(WARNING) << reason; | |||
| MS_LOG(DEBUG) << reason; | |||
| } | |||
| return ret; | |||
| } | |||
| @@ -282,7 +282,7 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(ERROR) << reason; | |||
| MS_LOG(WARNING) << reason; | |||
| return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn; | |||
| } | |||
| return ResultCode::kSuccess; | |||
| @@ -305,7 +305,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, | |||
| const std::string &next_req_time, | |||
| std::map<std::string, AddressPtr> feature_maps) { | |||
| if (fbb == nullptr) { | |||
| MS_LOG(ERROR) << "Input fbb is nullptr."; | |||
| MS_LOG(WARNING) << "Input fbb is nullptr."; | |||
| return; | |||
| } | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| @@ -45,7 +45,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) { | |||
| bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| const std::shared_ptr<ps::core::MessageHandler> &message) { | |||
| MS_LOG(INFO) << "Launching UpdateModelKernel kernel."; | |||
| MS_LOG(DEBUG) << "Launching UpdateModelKernel kernel."; | |||
| std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>(); | |||
| if (fbb == nullptr || req_data == nullptr) { | |||
| @@ -100,7 +100,8 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| MS_LOG(INFO) << "verify signature passed!"; | |||
| } | |||
| result_code = UpdateModel(update_model_req, fbb); | |||
| PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas); | |||
| result_code = VerifyUpdateModel(update_model_req, fbb, device_metas); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| MS_LOG(ERROR) << "Updating model failed."; | |||
| GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| @@ -112,6 +113,14 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| } | |||
| result_code = UpdateModel(update_model_req, fbb, device_metas); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| MS_LOG(ERROR) << "Updating model failed."; | |||
| GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| } | |||
| IncreaseAcceptClientNum(); | |||
| GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| @@ -155,8 +164,8 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr | |||
| return ResultCode::kSuccess; | |||
| } | |||
| ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, | |||
| const std::shared_ptr<FBBuilder> &fbb) { | |||
| ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req, | |||
| const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| size_t iteration = IntToSize(update_model_req->iteration()); | |||
| if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { | |||
| @@ -169,7 +178,6 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| return ResultCode::kSuccessAndReturn; | |||
| } | |||
| PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas); | |||
| const auto &fl_id_to_meta = device_metas.device_metas().fl_id_to_meta(); | |||
| std::string update_model_fl_id = update_model_req->fl_id()->str(); | |||
| MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id; | |||
| @@ -198,7 +206,15 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| return ResultCode::kSuccessAndReturn; | |||
| } | |||
| } | |||
| return ResultCode::kSuccess; | |||
| } | |||
| ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, | |||
| const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| const auto &fl_id_to_meta = device_metas.device_metas().fl_id_to_meta(); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kSuccessAndReturn); | |||
| std::string update_model_fl_id = update_model_req->fl_id()->str(); | |||
| size_t data_size = fl_id_to_meta.at(update_model_fl_id).data_size(); | |||
| const auto &feature_map = ParseFeatureMap(update_model_req); | |||
| if (feature_map.empty()) { | |||
| @@ -52,14 +52,16 @@ class UpdateModelKernel : public RoundKernel { | |||
| private: | |||
| ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb); | |||
| ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb); | |||
| ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb, | |||
| const PBMetadata &device_metas); | |||
| std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req); | |||
| ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestUpdateModel *update_model_req); | |||
| sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req); | |||
| void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const std::string &next_req_time); | |||
| ResultCode VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req, | |||
| const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas); | |||
| // The executor is for updating the model for updateModel request. | |||
| Executor *executor_{nullptr}; | |||
| @@ -43,7 +43,7 @@ class MemoryRegister { | |||
| // avoid its data being released. | |||
| template <typename T> | |||
| void RegisterArray(const std::string &name, std::unique_ptr<T[]> *array, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(array); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(array); | |||
| void *data = array->get(); | |||
| AddressPtr addr = std::make_shared<Address>(); | |||
| addr->addr = data; | |||
| @@ -62,7 +62,7 @@ class MemoryRegister { | |||
| auto char_arr = CastUniquePtr<char, T>(array); | |||
| StoreCharArray(&char_arr); | |||
| } else { | |||
| MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name(); | |||
| MS_LOG(WARNING) << "MemoryRegister does not support type " << typeid(T).name(); | |||
| return; | |||
| } | |||
| @@ -90,7 +90,7 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration | |||
| std::unique_lock<std::mutex> lock(model_mtx_); | |||
| std::map<std::string, AddressPtr> model = {}; | |||
| if (iteration_to_model_.count(iteration) == 0) { | |||
| MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored."; | |||
| MS_LOG(WARNING) << "Model for iteration " << iteration << " is not stored. Return latest model"; | |||
| return model; | |||
| } | |||
| model = iteration_to_model_[iteration]->addresses(); | |||
| @@ -114,8 +114,8 @@ size_t ModelStore::model_size() const { return model_size_; } | |||
| std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() { | |||
| std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel(); | |||
| if (model.empty()) { | |||
| MS_LOG(EXCEPTION) << "Model feature map is empty."; | |||
| return nullptr; | |||
| MS_LOG(WARNING) << "Model feature map is empty."; | |||
| return std::make_shared<MemoryRegister>(); | |||
| } | |||
| // Assign new memory for the model. | |||
| @@ -25,6 +25,8 @@ namespace fl { | |||
| namespace server { | |||
| class Server; | |||
| class Iteration; | |||
// Call counter and period used by Round::IsServerAvailable to rate-limit its safemode
// warning: the warning is logged when kPrintTimes % kPrintTimesThreshold == 0, the
// counter is then reset to 0, and it is incremented on every call rejected for safemode.
// NOTE(review): the modulo check, the reset and the increment are separate operations
// on the atomic, so with concurrent requests the log cadence is presumably only
// approximate — confirm whether that is acceptable.
| std::atomic<uint32_t> kPrintTimes = 0; | |||
| const uint32_t kPrintTimesThreshold = 3000; | |||
| Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count, | |||
| bool server_num_as_threshold) | |||
| : name_(name), | |||
| @@ -99,7 +101,7 @@ bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t | |||
| threshold_count_ = updated_threshold_count; | |||
| if (check_count_) { | |||
| if (!DistributedCountService::GetInstance().ReInitCounter(name_, threshold_count_)) { | |||
| MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed."; | |||
| MS_LOG(WARNING) << "Reinitializing count for " << name_ << " failed."; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -124,7 +126,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m | |||
| std::string reason = ""; | |||
| if (!IsServerAvailable(&reason)) { | |||
| if (!message->SendResponse(reason.c_str(), reason.size())) { | |||
| MS_LOG(ERROR) << "Sending response failed."; | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| return; | |||
| @@ -198,7 +200,11 @@ bool Round::IsServerAvailable(std::string *reason) { | |||
| // If the server is still in safemode, reject the request. | |||
| if (Server::GetInstance().IsSafeMode()) { | |||
| MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later."; | |||
| if (kPrintTimes % kPrintTimesThreshold == 0) { | |||
| MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later."; | |||
| kPrintTimes = 0; | |||
| } | |||
| kPrintTimes += 1; | |||
| *reason = ps::kClusterSafeMode; | |||
| return false; | |||
| } | |||
| @@ -22,6 +22,7 @@ namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| bool ServerRecovery::Initialize(const std::string &config_file) { | |||
| std::unique_lock<std::mutex> lock(server_recovery_file_mtx_); | |||
| config_ = std::make_unique<ps::core::FileConfiguration>(config_file); | |||
| MS_EXCEPTION_IF_NULL(config_); | |||
| if (!config_->Initialize()) { | |||
| @@ -59,6 +60,7 @@ bool ServerRecovery::Initialize(const std::string &config_file) { | |||
| } | |||
| bool ServerRecovery::Recover() { | |||
| std::unique_lock<std::mutex> lock(server_recovery_file_mtx_); | |||
| server_recovery_file_.open(server_recovery_file_path_, std::ios::in); | |||
| if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) { | |||
| MS_LOG(WARNING) << "Can't open server recovery file " << server_recovery_file_path_; | |||
| @@ -80,6 +82,7 @@ bool ServerRecovery::Recover() { | |||
| } | |||
| bool ServerRecovery::Save(uint64_t current_iter) { | |||
| std::unique_lock<std::mutex> lock(server_recovery_file_mtx_); | |||
| server_recovery_file_.open(server_recovery_file_path_, std::ios::out | std::ios::ate); | |||
| if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) { | |||
| MS_LOG(WARNING) << "Can't save data to recovery file " << server_recovery_file_path_ | |||
| @@ -96,6 +99,7 @@ bool ServerRecovery::Save(uint64_t current_iter) { | |||
| bool ServerRecovery::SyncAfterRecovery(const std::shared_ptr<ps::core::TcpCommunicator> &communicator, | |||
| uint32_t rank_id) { | |||
| std::unique_lock<std::mutex> lock(server_recovery_file_mtx_); | |||
| // If this server is follower server, notify leader server that this server has recovered. | |||
| if (rank_id != kLeaderServerRank) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(communicator, false); | |||
| @@ -22,6 +22,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <mutex> | |||
| #include "ps/core/recovery_base.h" | |||
| #include "ps/core/file_configuration.h" | |||
| #include "ps/core/communicator/tcp_communicator.h" | |||
| @@ -58,6 +59,7 @@ class ServerRecovery : public ps::core::RecoveryBase { | |||
| // The server recovery file object. | |||
| std::fstream server_recovery_file_; | |||
| std::mutex server_recovery_file_mtx_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace fl | |||
| @@ -669,7 +669,6 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con | |||
| std::lock_guard<std::mutex> lock(client_mutex_); | |||
| connected_nodes_.clear(); | |||
| PersistMetaData(); | |||
| } | |||
| void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, | |||
| @@ -696,7 +695,6 @@ void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &con | |||
| } | |||
| is_ready_ = true; | |||
| UpdateClusterState(ClusterState::CLUSTER_READY); | |||
| PersistMetaData(); | |||
| } | |||
| void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, | |||
| @@ -710,7 +708,6 @@ void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn | |||
| } | |||
| is_ready_ = true; | |||
| UpdateClusterState(ClusterState::CLUSTER_READY); | |||
| PersistMetaData(); | |||
| } | |||
| void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, | |||
| @@ -860,8 +857,7 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) { | |||
| void AbstractNode::InitClientToServer() { | |||
| // create tcp client to myself in case of event dispatch failed when Send msg to server 0 failed | |||
| client_to_server_ = | |||
| std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, config_.get(), node_info_.node_role_); | |||
| client_to_server_ = std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, node_info_.node_role_); | |||
| MS_EXCEPTION_IF_NULL(client_to_server_); | |||
| client_to_server_->Init(); | |||
| MS_LOG(INFO) << "The node start a tcp client to this node!"; | |||
| @@ -872,8 +868,7 @@ bool AbstractNode::InitClientToScheduler() { | |||
| MS_LOG(WARNING) << "The config is empty."; | |||
| return false; | |||
| } | |||
| client_to_scheduler_ = | |||
| std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, config_.get(), NodeRole::SCHEDULER); | |||
| client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, NodeRole::SCHEDULER); | |||
| MS_EXCEPTION_IF_NULL(client_to_scheduler_); | |||
| client_to_scheduler_->SetMessageCallback( | |||
| [&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) { | |||
| @@ -931,7 +926,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3 | |||
| MS_LOG(INFO) << "Create tcp client for role: " << role << ", rank: " << rank_id; | |||
| std::string ip = nodes_address_[key].first; | |||
| uint16_t port = nodes_address_[key].second; | |||
| auto client = std::make_shared<TcpClient>(ip, port, config_.get(), role); | |||
| auto client = std::make_shared<TcpClient>(ip, port, role); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, | |||
| size_t size) { | |||
| @@ -37,15 +37,14 @@ event_base *TcpClient::event_base_ = nullptr; | |||
| std::mutex TcpClient::event_base_mutex_; | |||
| bool TcpClient::is_started_ = false; | |||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config, NodeRole peer_role) | |||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port, NodeRole peer_role) | |||
| : event_timeout_(nullptr), | |||
| buffer_event_(nullptr), | |||
| server_address_(std::move(address)), | |||
| server_port_(port), | |||
| peer_role_(peer_role), | |||
| is_stop_(true), | |||
| is_connected_(false), | |||
| config_(config) { | |||
| is_connected_(false) { | |||
| message_handler_.SetCallback( | |||
| [this](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) { | |||
| if (message_callback_) { | |||
| @@ -54,7 +54,7 @@ class TcpClient { | |||
| std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>; | |||
| using OnTimer = std::function<void()>; | |||
| explicit TcpClient(const std::string &address, std::uint16_t port, Configuration *const config, NodeRole peer_role); | |||
| explicit TcpClient(const std::string &address, std::uint16_t port, NodeRole peer_role); | |||
| virtual ~TcpClient(); | |||
| std::string GetServerAddress() const; | |||
| @@ -239,6 +239,7 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConn | |||
| void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | |||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||
| MS_LOG(INFO) << "Remove connection fd: " << fd; | |||
| connections_.erase(fd); | |||
| } | |||
| @@ -383,7 +384,7 @@ void TcpServer::EventCallbackInner(struct bufferevent *bev, std::int16_t events, | |||
| MS_EXCEPTION_IF_NULL(srv); | |||
| if (events & BEV_EVENT_EOF) { | |||
| MS_LOG(INFO) << "Event buffer end of file, a client is disconnected from this server!"; | |||
| MS_LOG(INFO) << "BEV_EVENT_EOF event is trigger!"; | |||
| // Notify about disconnection | |||
| if (srv->client_disconnection_) { | |||
| srv->client_disconnection_(*srv, *conn); | |||
| @@ -391,7 +392,7 @@ void TcpServer::EventCallbackInner(struct bufferevent *bev, std::int16_t events, | |||
| // Free connection structures | |||
| srv->RemoveConnection(conn->GetFd()); | |||
| } else if (events & BEV_EVENT_ERROR) { | |||
| MS_LOG(WARNING) << "Connect to server error."; | |||
| MS_LOG(WARNING) << "BEV_EVENT_ERROR event is trigger!"; | |||
| if (PSContext::instance()->enable_ssl()) { | |||
| uint64_t err = bufferevent_get_openssl_error(bev); | |||
| MS_LOG(WARNING) << "The error number is:" << err; | |||
| @@ -80,8 +80,8 @@ bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std:: | |||
| if (!client->SendMessage(meta, protos, data, size)) { | |||
| MS_LOG(WARNING) << "Client send message failed."; | |||
| } | |||
| MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; | |||
| return Wait(request_id, timeout); | |||
| } | |||
| @@ -20,6 +20,7 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| bool NodeRecovery::Recover() { | |||
| std::unique_lock<std::mutex> lock(recovery_mtx_); | |||
| if (recovery_storage_ == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -20,6 +20,7 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| bool RecoveryBase::Initialize(const std::string &config_json) { | |||
| std::unique_lock<std::mutex> lock(recovery_mtx_); | |||
| nlohmann::json recovery_config; | |||
| try { | |||
| recovery_config = nlohmann::json::parse(config_json); | |||
| @@ -87,7 +88,8 @@ bool RecoveryBase::InitializeNodes(const std::string &config_json) { | |||
| return true; | |||
| } | |||
| void RecoveryBase::Persist(const core::ClusterConfig &clusterConfig) const { | |||
| void RecoveryBase::Persist(const core::ClusterConfig &clusterConfig) { | |||
| std::unique_lock<std::mutex> lock(recovery_mtx_); | |||
| if (recovery_storage_ == nullptr) { | |||
| MS_LOG(WARNING) << "recovery storage is null, so don't persist meta data"; | |||
| return; | |||
| @@ -95,7 +97,8 @@ void RecoveryBase::Persist(const core::ClusterConfig &clusterConfig) const { | |||
| recovery_storage_->PersistFile(clusterConfig); | |||
| } | |||
| void RecoveryBase::PersistNodesInfo(const core::ClusterConfig &clusterConfig) const { | |||
| void RecoveryBase::PersistNodesInfo(const core::ClusterConfig &clusterConfig) { | |||
| std::unique_lock<std::mutex> lock(recovery_mtx_); | |||
| if (scheduler_recovery_storage_ == nullptr) { | |||
| MS_LOG(WARNING) << "scheduler recovery storage is null, so don't persist nodes meta data"; | |||
| return; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <mutex> | |||
| #include "ps/constants.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -50,10 +51,10 @@ class RecoveryBase { | |||
| virtual bool Recover() = 0; | |||
| // Persist metadata to storage. | |||
| virtual void Persist(const core::ClusterConfig &clusterConfig) const; | |||
| virtual void Persist(const core::ClusterConfig &clusterConfig); | |||
| // Persist metadata to storage. | |||
| virtual void PersistNodesInfo(const core::ClusterConfig &clusterConfig) const; | |||
| virtual void PersistNodesInfo(const core::ClusterConfig &clusterConfig); | |||
| protected: | |||
| // Persistent storage used to save metadata. | |||
| @@ -64,6 +65,8 @@ class RecoveryBase { | |||
| // Storage type for recovery,Currently only supports storage of file types | |||
| StorageType storage_type_; | |||
| std::mutex recovery_mtx_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -67,8 +67,8 @@ bool SchedulerNode::Start(const uint32_t &timeout) { | |||
| void SchedulerNode::RunRecovery() { | |||
| core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); | |||
| // create tcp client to myself in case of event dispatch failed when Send reconnect msg to server failed | |||
| client_to_scheduler_ = std::make_shared<TcpClient>(clusterConfig.scheduler_host, clusterConfig.scheduler_port, | |||
| config_.get(), NodeRole::SCHEDULER); | |||
| client_to_scheduler_ = | |||
| std::make_shared<TcpClient>(clusterConfig.scheduler_host, clusterConfig.scheduler_port, NodeRole::SCHEDULER); | |||
| MS_EXCEPTION_IF_NULL(client_to_scheduler_); | |||
| client_to_scheduler_->Init(); | |||
| client_thread_ = std::make_unique<std::thread>([this]() { | |||
| @@ -95,7 +95,7 @@ void SchedulerNode::RunRecovery() { | |||
| for (const auto &kvs : initial_node_infos) { | |||
| auto &node_id = kvs.first; | |||
| auto &node_info = kvs.second; | |||
| auto client = std::make_shared<TcpClient>(node_info.ip_, node_info.port_, config_.get(), node_info.node_role_); | |||
| auto client = std::make_shared<TcpClient>(node_info.ip_, node_info.port_, node_info.node_role_); | |||
| client->SetMessageCallback([this](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *, size_t) { | |||
| MS_LOG(INFO) << "received the response. "; | |||
| NotifyMessageArrival(meta); | |||
| @@ -647,7 +647,7 @@ const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInf | |||
| std::string ip = node_info.ip_; | |||
| uint16_t port = node_info.port_; | |||
| MS_LOG(INFO) << "ip:" << ip << ", port:" << port << ", node id:" << node_info.node_id_; | |||
| auto client = std::make_shared<TcpClient>(ip, port, config_.get(), node_info.node_role_); | |||
| auto client = std::make_shared<TcpClient>(ip, port, node_info.node_role_); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| client->SetMessageCallback( | |||
| [&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) { | |||
| @@ -20,11 +20,13 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| std::string SchedulerRecovery::GetMetadata(const std::string &key) { | |||
| std::unique_lock<std::mutex> lock(recovery_mtx_); | |||
| MS_EXCEPTION_IF_NULL(recovery_storage_); | |||
| return recovery_storage_->Get(key, ""); | |||
| } | |||
| bool SchedulerRecovery::Recover() { | |||
| std::unique_lock<std::mutex> lock(recovery_mtx_); | |||
| if (recovery_storage_ == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -28,8 +28,7 @@ class TestTcpClient : public UT::Common { | |||
| }; | |||
| TEST_F(TestTcpClient, InitClientIPError) { | |||
| std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>(""); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000, config.get(), NodeRole::SERVER); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000, NodeRole::SERVER); | |||
| client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) { | |||
| CommMessage message; | |||
| @@ -42,8 +41,7 @@ TEST_F(TestTcpClient, InitClientIPError) { | |||
| } | |||
| TEST_F(TestTcpClient, InitClientPortErrorNoException) { | |||
| std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>(""); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.1", -1, config.get(), NodeRole::SERVER); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.1", -1, NodeRole::SERVER); | |||
| client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) { | |||
| CommMessage message; | |||
| @@ -60,7 +60,7 @@ class TestTcpServer : public UT::Common { | |||
| TEST_F(TestTcpServer, ServerSendMessage) { | |||
| std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>(""); | |||
| client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort(), config.get(), NodeRole::SERVER); | |||
| client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort(), NodeRole::SERVER); | |||
| std::cout << server_->BoundPort() << std::endl; | |||
| std::unique_ptr<std::thread> http_client_thread(nullptr); | |||
| http_client_thread = std::make_unique<std::thread>([&]() { | |||