| @@ -27,7 +27,11 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size | |||
| size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt, | |||
| size_t cipher_reconstruct_secrets_up_cnt) { | |||
| MS_LOG(INFO) << "CipherInit::Init START"; | |||
| if (memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN) != 0) { | |||
| if (publicparam_.p == nullptr || param.p == nullptr || param.prime == nullptr || publicparam_.prime == nullptr) { | |||
| MS_LOG(ERROR) << "CipherInit::input data invalid."; | |||
| return false; | |||
| } | |||
| if (memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, sizeof(param.p)) != 0) { | |||
| MS_LOG(ERROR) << "CipherInit::memory copy failed."; | |||
| return false; | |||
| } | |||
| @@ -43,25 +43,28 @@ class CipherInit { | |||
| size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt, | |||
| size_t cipher_reconstruct_secrets_up_cnt); | |||
| // Check whether the parameters are valid. | |||
| bool Check_Parames(); | |||
| // Get public params. which is given to start fl job thread. | |||
| CipherPublicPara *GetPublicParams() { return &publicparam_; } | |||
| size_t share_secrets_threshold; // the minimum number of clients to share secret fragments. | |||
| size_t get_secrets_threshold; // the minimum number of clients to get secret fragments. | |||
| size_t reconstruct_secrets_threshold; // the minimum number of clients to reconstruct secret mask. | |||
| size_t exchange_key_threshold; // the minimum number of clients to send public keys. | |||
| size_t get_key_threshold; // the minimum number of clients to get public keys. | |||
| size_t client_list_threshold; // the minimum number of clients to get update model client list. | |||
| size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask. | |||
| size_t featuremap_; // the size of data to deal. | |||
| size_t time_out_mutex_; // timeout mutex. | |||
| size_t push_list_sign_threshold; // the minimum number of clients to push client list signature. | |||
| size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask. | |||
| size_t featuremap_; // the size of data to deal. | |||
| CipherPublicPara publicparam_; // the param containing encrypted public parameters. | |||
| CipherMetaStorage cipher_meta_storage_; | |||
| private: | |||
| size_t client_list_threshold; // the minimum number of clients to get update model client list. | |||
| size_t get_key_threshold; // the minimum number of clients to get public keys. | |||
| size_t get_list_sign_threshold; // the minimum number of clients to get client list signature. | |||
| size_t get_secrets_threshold; // the minimum number of clients to get secret fragments. | |||
| size_t time_out_mutex_; // timeout mutex. | |||
| // Check whether the parameters are valid. | |||
| bool Check_Parames(); | |||
| }; | |||
| } // namespace armour | |||
| } // namespace mindspore | |||
| @@ -19,12 +19,16 @@ | |||
| namespace mindspore { | |||
| namespace armour { | |||
| bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time, | |||
| bool CipherKeys::GetKeys(const size_t cur_iterator, const std::string &next_req_time, | |||
| const schema::GetExchangeKeys *get_exchange_keys_req, | |||
| const std::shared_ptr<fl::server::FBBuilder> &fbb) { | |||
| MS_LOG(INFO) << "CipherMgr::GetKeys START"; | |||
| if (get_exchange_keys_req == nullptr) { | |||
| MS_LOG(ERROR) << "Request is nullptr"; | |||
| BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false); | |||
| return false; | |||
| } | |||
| if (cipher_init_ == nullptr) { | |||
| BuildGetKeysRsp(fbb, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false); | |||
| return false; | |||
| } | |||
| @@ -61,7 +65,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim | |||
| return true; | |||
| } | |||
| bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time, | |||
| bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next_req_time, | |||
| const schema::RequestExchangeKeys *exchange_keys_req, | |||
| const std::shared_ptr<fl::server::FBBuilder> &fbb) { | |||
| MS_LOG(INFO) << "CipherMgr::ExchangeKeys START"; | |||
| @@ -72,6 +76,11 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re | |||
| BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason, next_req_time, cur_iterator); | |||
| return false; | |||
| } | |||
| if (cipher_init_ == nullptr) { | |||
| std::string reason = "cipher_init_ is nullptr"; | |||
| BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, reason, next_req_time, cur_iterator); | |||
| return false; | |||
| } | |||
| std::string fl_id = exchange_keys_req->fl_id()->str(); | |||
| mindspore::fl::PBMetadata device_metas = | |||
| fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxDeviceMetas); | |||
| @@ -127,9 +136,9 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re | |||
| } | |||
| } | |||
| void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const std::string &next_req_time, | |||
| const int iteration) { | |||
| void CipherKeys::BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, | |||
| const schema::ResponseCode retcode, const std::string &reason, | |||
| const std::string &next_req_time, const size_t iteration) { | |||
| auto rsp_reason = fbb->CreateString(reason); | |||
| auto rsp_next_req_time = fbb->CreateString(next_req_time); | |||
| @@ -137,21 +146,21 @@ void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb | |||
| rsp_builder.add_retcode(retcode); | |||
| rsp_builder.add_reason(rsp_reason); | |||
| rsp_builder.add_next_req_time(rsp_next_req_time); | |||
| rsp_builder.add_iteration(iteration); | |||
| rsp_builder.add_iteration(SizeToInt(iteration)); | |||
| auto rsp_exchange_keys = rsp_builder.Finish(); | |||
| fbb->Finish(rsp_exchange_keys); | |||
| return; | |||
| } | |||
| void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode, | |||
| const int iteration, const std::string &next_req_time, bool is_good) { | |||
| void CipherKeys::BuildGetKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const size_t iteration, const std::string &next_req_time, bool is_good) { | |||
| if (!is_good) { | |||
| auto fbs_next_req_time = fbb->CreateString(next_req_time); | |||
| schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get())); | |||
| rsp_buider.add_retcode(retcode); | |||
| rsp_buider.add_iteration(iteration); | |||
| rsp_buider.add_next_req_time(fbs_next_req_time); | |||
| auto rsp_get_keys = rsp_buider.Finish(); | |||
| schema::ReturnExchangeKeysBuilder rsp_builder(*(fbb.get())); | |||
| rsp_builder.add_retcode(static_cast<int>(retcode)); | |||
| rsp_builder.add_iteration(SizeToInt(iteration)); | |||
| rsp_builder.add_next_req_time(fbs_next_req_time); | |||
| auto rsp_get_keys = rsp_builder.Finish(); | |||
| fbb->Finish(rsp_get_keys); | |||
| return; | |||
| } | |||
| @@ -176,12 +185,12 @@ void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, con | |||
| } | |||
| auto remote_publickeys = fbb->CreateVector(public_keys_list); | |||
| auto fbs_next_req_time = fbb->CreateString(next_req_time); | |||
| schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get())); | |||
| rsp_buider.add_retcode(retcode); | |||
| rsp_buider.add_iteration(iteration); | |||
| rsp_buider.add_remote_publickeys(remote_publickeys); | |||
| rsp_buider.add_next_req_time(fbs_next_req_time); | |||
| auto rsp_get_keys = rsp_buider.Finish(); | |||
| schema::ReturnExchangeKeysBuilder rsp_builder(*(fbb.get())); | |||
| rsp_builder.add_retcode(static_cast<int>(retcode)); | |||
| rsp_builder.add_iteration(SizeToInt(iteration)); | |||
| rsp_builder.add_remote_publickeys(remote_publickeys); | |||
| rsp_builder.add_next_req_time(fbs_next_req_time); | |||
| auto rsp_get_keys = rsp_builder.Finish(); | |||
| fbb->Finish(rsp_get_keys); | |||
| MS_LOG(INFO) << "CipherMgr::GetKeys Success"; | |||
| return; | |||
| @@ -190,6 +199,7 @@ void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, con | |||
| void CipherKeys::ClearKeys() { | |||
| fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxExChangeKeysClientList); | |||
| fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsKeys); | |||
| fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxGetKeysClientList); | |||
| } | |||
| } // namespace armour | |||
| @@ -36,6 +36,7 @@ class CipherKeys { | |||
| public: | |||
| // initialize: get cipher_init_ | |||
| CipherKeys() { cipher_init_ = &CipherInit::GetInstance(); } | |||
| ~CipherKeys() = default; | |||
| static CipherKeys &GetInstance() { | |||
| static CipherKeys instance; | |||
| @@ -43,20 +44,20 @@ class CipherKeys { | |||
| } | |||
| // handle the client's request of get keys. | |||
| bool GetKeys(const int cur_iterator, const std::string &next_req_time, | |||
| bool GetKeys(const size_t cur_iterator, const std::string &next_req_time, | |||
| const schema::GetExchangeKeys *get_exchange_keys_req, const std::shared_ptr<fl::server::FBBuilder> &fbb); | |||
| // handle the client's request of exchange keys. | |||
| bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time, | |||
| bool ExchangeKeys(const size_t cur_iterator, const std::string &next_req_time, | |||
| const schema::RequestExchangeKeys *exchange_keys_req, | |||
| const std::shared_ptr<fl::server::FBBuilder> &fbb); | |||
| // build response code of get keys. | |||
| void BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode, | |||
| const int iteration, const std::string &next_req_time, bool is_good); | |||
| void BuildGetKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const size_t iteration, const std::string &next_req_time, bool is_good); | |||
| // build response code of exchange keys. | |||
| void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const std::string &next_req_time, const int iteration); | |||
| void BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const std::string &next_req_time, const size_t iteration); | |||
| // clear the shared memory. | |||
| void ClearKeys(); | |||
| @@ -21,6 +21,10 @@ namespace armour { | |||
| void CipherMetaStorage::GetClientSharesFromServer( | |||
| const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) { | |||
| if (clients_shares_list == nullptr) { | |||
| MS_LOG(ERROR) << "input clients_shares_list is nullptr"; | |||
| return; | |||
| } | |||
| const fl::PBMetadata &clients_shares_pb_out = | |||
| fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); | |||
| const fl::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares(); | |||
| @@ -29,7 +33,8 @@ void CipherMetaStorage::GetClientSharesFromServer( | |||
| std::string fl_id = iter->first; | |||
| const fl::SharesPb &shares_pb = iter->second; | |||
| std::vector<clientshare_str> encrpted_shares_new; | |||
| for (int index_shares = 0; index_shares < shares_pb.clientsharestrs_size(); ++index_shares) { | |||
| size_t client_share_num = IntToSize(shares_pb.clientsharestrs_size()); | |||
| for (size_t index_shares = 0; index_shares < client_share_num; ++index_shares) { | |||
| const fl::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares); | |||
| clientshare_str new_clientshare; | |||
| new_clientshare.fl_id = client_share_str_pb.fl_id(); | |||
| @@ -42,9 +47,14 @@ void CipherMetaStorage::GetClientSharesFromServer( | |||
| } | |||
| void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) { | |||
| if (clients_list == nullptr) { | |||
| MS_LOG(ERROR) << "input clients_list is nullptr"; | |||
| return; | |||
| } | |||
| const fl::PBMetadata &client_list_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); | |||
| const fl::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list(); | |||
| for (int i = 0; i < client_list_pb.fl_id_size(); ++i) { | |||
| size_t client_list_num = IntToSize(client_list_pb.fl_id_size()); | |||
| for (size_t i = 0; i < client_list_num; ++i) { | |||
| std::string fl_id = client_list_pb.fl_id(i); | |||
| clients_list->push_back(fl_id); | |||
| } | |||
| @@ -52,6 +62,10 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect | |||
| void CipherMetaStorage::GetClientKeysFromServer( | |||
| const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) { | |||
| if (clients_keys_list == nullptr) { | |||
| MS_LOG(ERROR) << "input clients_keys_list is nullptr"; | |||
| return; | |||
| } | |||
| const fl::PBMetadata &clients_keys_pb_out = | |||
| fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); | |||
| const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys(); | |||
| @@ -65,12 +79,16 @@ void CipherMetaStorage::GetClientKeysFromServer( | |||
| std::vector<std::vector<uint8_t>> cur_keys; | |||
| cur_keys.push_back(cpk); | |||
| cur_keys.push_back(spk); | |||
| clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys)); | |||
| (void)clients_keys_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys)); | |||
| } | |||
| } | |||
| void CipherMetaStorage::GetClientIVsFromServer( | |||
| const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) { | |||
| if (clients_ivs_list == nullptr) { | |||
| MS_LOG(ERROR) << "input clients_ivs_list is nullptr"; | |||
| return; | |||
| } | |||
| const fl::PBMetadata &clients_keys_pb_out = | |||
| fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); | |||
| const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys(); | |||
| @@ -86,25 +104,28 @@ void CipherMetaStorage::GetClientIVsFromServer( | |||
| cur_ivs.push_back(ind_iv); | |||
| cur_ivs.push_back(pw_iv); | |||
| cur_ivs.push_back(pw_salt); | |||
| clients_ivs_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs)); | |||
| (void)clients_ivs_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs)); | |||
| } | |||
| } | |||
| bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) { | |||
| if (cur_public_noise == nullptr) { | |||
| MS_LOG(ERROR) << "input cur_public_noise is nullptr"; | |||
| return false; | |||
| } | |||
| const fl::PBMetadata &clients_noises_pb_out = | |||
| fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); | |||
| const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises(); | |||
| int count = 0; | |||
| int count_thld = 100; | |||
| const int count_thld = 1000; | |||
| while (clients_noises_pb.has_one_client_noises() == false) { | |||
| int register_time = 500; | |||
| const int register_time = 500; | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(register_time)); | |||
| count++; | |||
| if (count >= count_thld) break; | |||
| } | |||
| MS_LOG(INFO) << "GetClientNoisesFromServer Count: " << count; | |||
| if (clients_noises_pb.has_one_client_noises() == false) { | |||
| MS_LOG(ERROR) << "GetClientNoisesFromServer NULL."; | |||
| MS_LOG(WARNING) << "GetClientNoisesFromServer Count: " << count; | |||
| return false; | |||
| } | |||
| cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(), | |||
| @@ -112,17 +133,22 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve | |||
| return true; | |||
| } | |||
| bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) { | |||
| bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, uint8_t *prime) { | |||
| if (prime == nullptr) { | |||
| MS_LOG(ERROR) << "input prime is nullptr"; | |||
| return false; | |||
| } | |||
| const fl::PBMetadata &prime_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name); | |||
| fl::Prime prime_pb(prime_pb_out.prime()); | |||
| std::string str = *(prime_pb.mutable_prime()); | |||
| MS_LOG(INFO) << "get prime from metastorage :" << str; | |||
| if (str.size() != PRIME_MAX_LEN) { | |||
| MS_LOG(ERROR) << "get prime size is :" << str.size(); | |||
| return false; | |||
| } else { | |||
| memcpy_s(prime, PRIME_MAX_LEN, str.data(), PRIME_MAX_LEN); | |||
| if (memcpy_s(prime, PRIME_MAX_LEN, str.data(), str.size()) != 0) { | |||
| MS_LOG(ERROR) << "Memcpy_s error"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } | |||
| @@ -143,12 +169,13 @@ void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string & | |||
| fl::PBMetadata prime_pb; | |||
| prime_pb.mutable_prime()->MergeFrom(prime_id_pb); | |||
| fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb); | |||
| sleep(1); | |||
| uint32_t time = 1; | |||
| (void)sleep(time); | |||
| } | |||
| bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id, | |||
| const std::vector<std::vector<uint8_t>> &cur_public_key) { | |||
| size_t correct_size = 2; | |||
| const size_t correct_size = 2; | |||
| if (cur_public_key.size() < correct_size) { | |||
| MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size(); | |||
| return false; | |||
| @@ -245,14 +272,18 @@ bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const s | |||
| bool CipherMetaStorage::UpdateClientShareToServer( | |||
| const char *list_name, const std::string &fl_id, | |||
| const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) { | |||
| int size_shares = shares->size(); | |||
| if (shares == nullptr) { | |||
| return false; | |||
| } | |||
| size_t size_shares = shares->size(); | |||
| fl::SharesPb shares_pb; | |||
| for (int index = 0; index < size_shares; ++index) { | |||
| for (size_t index = 0; index < size_shares; ++index) { | |||
| // new item | |||
| fl::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs(); | |||
| std::string fl_id_new = (*shares)[index]->fl_id()->str(); | |||
| int index_new = (*shares)[index]->index(); | |||
| auto share = (*shares)[index]->share(); | |||
| if (share == nullptr) return false; | |||
| client_share_str_new_p->set_share(reinterpret_cast<const char *>(share->data()), share->size()); | |||
| client_share_str_new_p->set_fl_id(fl_id_new); | |||
| client_share_str_new_p->set_index(index_new); | |||
| @@ -293,6 +324,8 @@ void CipherMetaStorage::RegisterClass() { | |||
| fl::PBMetadata get_update_clients_list; | |||
| fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetUpdateModelClientList, | |||
| get_update_clients_list); | |||
| fl::PBMetadata client_noises; | |||
| fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientNoises, client_noises); | |||
| } | |||
| } // namespace armour | |||
| } // namespace mindspore | |||
| @@ -58,8 +58,8 @@ bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp, | |||
| for (int i = 0; i < static_cast<int>(cipher_init_->secrets_minnums_); ++i) { | |||
| shares_tmp->at(i)->index = (iter->second)[i].index; | |||
| shares_tmp->at(i)->len = (iter->second)[i].share.size(); | |||
| if (memcpy_s(shares_tmp->at(i)->data, shares_tmp->at(i)->len, (iter->second)[i].share.data(), | |||
| shares_tmp->at(i)->len) != 0) { | |||
| if (memcpy_s(shares_tmp->at(i)->data, SHARE_MAX_SIZE, (iter->second)[i].share.data(), shares_tmp->at(i)->len) != | |||
| 0) { | |||
| MS_LOG(ERROR) << "shares_tmp copy failed"; | |||
| retcode = false; | |||
| } | |||
| @@ -78,11 +78,15 @@ bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp, | |||
| // reconstruct pairwise noise | |||
| MS_LOG(INFO) << "start reconstruct pairwise noise."; | |||
| std::vector<float> noise(cipher_init_->featuremap_, 0.0); | |||
| if (GetSuvNoise(clients_share_list, record_public_keys, client_ivs, fl_id, &noise, secret, length) == false) | |||
| retcode = false; | |||
| client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise)); | |||
| MS_LOG(INFO) << " fl_id : " << fl_id; | |||
| MS_LOG(INFO) << "end get complete s_uv."; | |||
| if (GetSuvNoise(clients_share_list, record_public_keys, client_ivs, fl_id, &noise, secret, length) == false) { | |||
| MS_LOG(ERROR) << "GetSuvNoise failed"; | |||
| BN_clear_free(prime); | |||
| if (memset_s(secret, SECRET_MAX_LEN, 0, length) != 0) { | |||
| MS_LOG(EXCEPTION) << "Memset failed."; | |||
| } | |||
| return false; | |||
| } | |||
| (void)client_noise->emplace(std::pair<std::string, std::vector<float>>(fl_id, noise)); | |||
| } else { | |||
| // reconstruct individual noise | |||
| MS_LOG(INFO) << "start reconstruct individual noise."; | |||
| @@ -97,13 +101,23 @@ bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp, | |||
| return false; | |||
| } | |||
| std::vector<uint8_t> ind_iv = it->second[0]; | |||
| if (Masking::GetMasking(&noise, cipher_init_->featuremap_, (const uint8_t *)secret, SECRET_MAX_LEN, | |||
| ind_iv.data(), ind_iv.size()) < 0) | |||
| retcode = false; | |||
| if (Masking::GetMasking(&noise, SizeToInt(cipher_init_->featuremap_), (const uint8_t *)secret, SECRET_MAX_LEN, | |||
| ind_iv.data(), SizeToInt(ind_iv.size())) < 0) { | |||
| MS_LOG(ERROR) << "Get Masking failed"; | |||
| if (memset_s(secret, SECRET_MAX_LEN, 0, length) != 0) { | |||
| MS_LOG(EXCEPTION) << "Memset failed."; | |||
| } | |||
| BN_clear_free(prime); | |||
| return false; | |||
| } | |||
| for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) { | |||
| noise[index_noise] *= -1; | |||
| } | |||
| client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise)); | |||
| (void)client_noise->emplace(std::pair<std::string, std::vector<float>>(fl_id, noise)); | |||
| } | |||
| BN_clear_free(prime); | |||
| if (memset_s(secret, SECRET_MAX_LEN, 0, length) != 0) { | |||
| MS_LOG(EXCEPTION) << "Memset failed."; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "reconstruct secret failed: the number of secret shares for fl_id: " << fl_id | |||
| @@ -157,6 +171,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl | |||
| std::vector<Share *> shares_tmp; | |||
| if (!MallocShares(&shares_tmp, cipher_init_->secrets_minnums_)) { | |||
| MS_LOG(ERROR) << "Reconstruct malloc shares_tmp invalid."; | |||
| DeleteShares(&shares_tmp); | |||
| return false; | |||
| } | |||
| @@ -185,6 +200,24 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl | |||
| return retcode; | |||
| } | |||
| bool CipherReconStruct::CheckInputs(const schema::SendReconstructSecret *reconstruct_secret_req, | |||
| const std::shared_ptr<fl::server::FBBuilder> &fbb, const int cur_iterator, | |||
| const std::string &next_req_time) { | |||
| if (reconstruct_secret_req == nullptr) { | |||
| std::string reason = "Request is nullptr"; | |||
| MS_LOG(ERROR) << reason; | |||
| BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, cur_iterator, next_req_time); | |||
| return false; | |||
| } | |||
| if (cipher_init_ == nullptr) { | |||
| std::string reason = "cipher_init_ is nullptr"; | |||
| MS_LOG(ERROR) << reason; | |||
| BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError, reason, cur_iterator, next_req_time); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| // reconstruct secrets | |||
| bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time, | |||
| const schema::SendReconstructSecret *reconstruct_secret_req, | |||
| @@ -192,13 +225,8 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st | |||
| const std::vector<std::string> &client_list) { | |||
| MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START"; | |||
| clock_t start_time = clock(); | |||
| if (reconstruct_secret_req == nullptr) { | |||
| std::string reason = "Request is nullptr"; | |||
| MS_LOG(ERROR) << reason; | |||
| BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, cur_iterator, next_req_time); | |||
| return false; | |||
| } | |||
| bool inputs_check = CheckInputs(reconstruct_secret_req, fbb, cur_iterator, next_req_time); | |||
| if (!inputs_check) return false; | |||
| int iterator = reconstruct_secret_req->iteration(); | |||
| std::string fl_id = reconstruct_secret_req->fl_id()->str(); | |||
| @@ -268,7 +296,8 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st | |||
| BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, | |||
| "Success, but the server is not ready to reconstruct secret yet.", cur_iterator, | |||
| next_req_time); | |||
| MS_LOG(INFO) << "ReconstructSecrets " << fl_id << " Success, but count " << count_client_num << "is not enough."; | |||
| MS_LOG(INFO) << "Get reconstruct shares from " << fl_id << " Success, but count " << count_client_num | |||
| << " is not enough."; | |||
| return true; | |||
| } | |||
| const fl::PBMetadata &clients_noises_pb_out = | |||
| @@ -339,7 +368,8 @@ void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::ser | |||
| bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_share_list, | |||
| const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys, | |||
| const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs, | |||
| const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length) { | |||
| const string &fl_id, std::vector<float> *noise, const uint8_t *secret, | |||
| size_t length) { | |||
| for (auto p_key = clients_share_list.begin(); p_key != clients_share_list.end(); ++p_key) { | |||
| if (*p_key != fl_id) { | |||
| PrivateKey *privKey = KeyAgreement::FromPrivateBytes(secret, length); | |||
| @@ -357,6 +387,7 @@ bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_shar | |||
| auto iter = client_ivs.find(iv_fl_id); | |||
| if (iter == client_ivs.end()) { | |||
| MS_LOG(ERROR) << "cannot get ivs for client: " << iv_fl_id; | |||
| delete privKey; | |||
| return false; | |||
| } | |||
| if (iter->second.size() != IV_NUM) { | |||
| @@ -372,8 +403,11 @@ bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_shar | |||
| } | |||
| MS_LOG(INFO) << "private_key fl_id : " << fl_id << " public_key fl_id : " << *p_key; | |||
| uint8_t secret1[SECRET_MAX_LEN] = {0}; | |||
| if (KeyAgreement::ComputeSharedKey(privKey, pubKey, SECRET_MAX_LEN, pw_salt.data(), pw_salt.size(), secret1) < | |||
| 0) { | |||
| int ret = KeyAgreement::ComputeSharedKey(privKey, pubKey, SECRET_MAX_LEN, pw_salt.data(), | |||
| SizeToInt(pw_salt.size()), secret1); | |||
| delete privKey; | |||
| delete pubKey; | |||
| if (ret < 0) { | |||
| MS_LOG(ERROR) << "ComputeSharedKey failed\n"; | |||
| return false; | |||
| } | |||
| @@ -426,7 +460,7 @@ bool CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::v | |||
| if (des->find(src_id) == des->end()) { // src_id is not in recombined shares list | |||
| std::vector<clientshare_str> value_list; | |||
| value_list.push_back(value); | |||
| des->insert(std::pair<std::string, std::vector<clientshare_str>>(src_id, value_list)); | |||
| (void)des->emplace(std::pair<std::string, std::vector<clientshare_str>>(src_id, value_list)); | |||
| } else { | |||
| des->at(src_id).push_back(value); | |||
| } | |||
| @@ -435,19 +469,18 @@ bool CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::v | |||
| return true; | |||
| } | |||
| bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int shares_size) { | |||
| bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, size_t shares_size) { | |||
| if (shares_tmp == nullptr) return false; | |||
| for (int i = 0; i < shares_size; ++i) { | |||
| Share *share_i = new Share(); | |||
| for (size_t i = 0; i < shares_size; ++i) { | |||
| Share *share_i = new (std::nothrow) Share(); | |||
| if (share_i == nullptr) { | |||
| MS_LOG(ERROR) << "shares_tmp " << i << " memory to cipher is invalid."; | |||
| DeleteShares(shares_tmp); | |||
| MS_LOG(ERROR) << "new Share failed."; | |||
| return false; | |||
| } | |||
| share_i->data = new uint8_t[SHARE_MAX_SIZE]; | |||
| if (share_i->data == nullptr) { | |||
| MS_LOG(ERROR) << "shares_tmp's data " << i << " memory to cipher is invalid."; | |||
| DeleteShares(shares_tmp); | |||
| MS_LOG(ERROR) << "malloc memory failed."; | |||
| delete share_i; | |||
| return false; | |||
| } | |||
| share_i->index = 0; | |||
| @@ -37,6 +37,7 @@ class CipherReconStruct { | |||
| public: | |||
| // initialize: get cipher_init_ | |||
| CipherReconStruct() { cipher_init_ = &CipherInit::GetInstance(); } | |||
| ~CipherReconStruct() = default; | |||
| static CipherReconStruct &GetInstance() { | |||
| static CipherReconStruct instance; | |||
| @@ -64,9 +65,9 @@ class CipherReconStruct { | |||
| bool GetSuvNoise(const std::vector<std::string> &clients_share_list, | |||
| const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys, | |||
| const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs, const string &fl_id, | |||
| std::vector<float> *noise, uint8_t *secret, int length); | |||
| std::vector<float> *noise, const uint8_t *secret, size_t length); | |||
| // malloc shares. | |||
| bool MallocShares(std::vector<Share *> *shares_tmp, int shares_size); | |||
| bool MallocShares(std::vector<Share *> *shares_tmp, size_t shares_size); | |||
| // delete shares. | |||
| void DeleteShares(std::vector<Share *> *shares_tmp); | |||
| // convert shares from receiving clients to sending clients. | |||
| @@ -84,6 +85,9 @@ class CipherReconStruct { | |||
| const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list, | |||
| const std::vector<string> &client_list, | |||
| const std::map<std::string, std::vector<std::vector<unsigned char>>> &client_ivs); | |||
| bool CheckInputs(const schema::SendReconstructSecret *reconstruct_secret_req, | |||
| const std::shared_ptr<fl::server::FBBuilder> &fbb, const int cur_iterator, | |||
| const std::string &next_req_time); | |||
| }; | |||
| } // namespace armour | |||
| } // namespace mindspore | |||
| @@ -31,7 +31,13 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha | |||
| cur_iterator); | |||
| return false; | |||
| } | |||
| if (cipher_init_ == nullptr) { | |||
| std::string reason = "cipher_init_ is nullptr"; | |||
| MS_LOG(ERROR) << reason; | |||
| BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SystemError, reason, next_req_time, | |||
| cur_iterator); | |||
| return false; | |||
| } | |||
| // step 1: get client list and share secrets from memory server. | |||
| clock_t start_time = clock(); | |||
| @@ -89,25 +95,28 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha | |||
| } | |||
| bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, | |||
| const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, | |||
| const std::string &next_req_time) { | |||
| const std::shared_ptr<fl::server::FBBuilder> &fbb, const std::string &next_req_time) { | |||
| MS_LOG(INFO) << "CipherShares::GetSecrets START"; | |||
| clock_t start_time = clock(); | |||
| // step 0: check whether the parameters are legal. | |||
| if (get_secrets_req == nullptr) { | |||
| BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, nullptr); | |||
| BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, 0, next_req_time, nullptr); | |||
| MS_LOG(ERROR) << "GetSecrets: get_secrets_req is nullptr."; | |||
| return false; | |||
| } | |||
| int iteration = get_secrets_req->iteration(); | |||
| if (cipher_init_ == nullptr) { | |||
| MS_LOG(ERROR) << "cipher_init_ is nullptr"; | |||
| BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, IntToSize(iteration), next_req_time, nullptr); | |||
| return false; | |||
| } | |||
| // step 1: get client list and client shares list from memory server. | |||
| std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all; | |||
| cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares, | |||
| &encrypted_shares_all); | |||
| int iteration = get_secrets_req->iteration(); | |||
| size_t encrypted_shares_num = encrypted_shares_all.size(); | |||
| if (cipher_init_->share_secrets_threshold > encrypted_shares_num) { // the client num is not enough, return false. | |||
| BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr); | |||
| BuildGetSecretsRsp(fbb, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr); | |||
| MS_LOG(INFO) << "GetSecrets: the encrypted shares num is not enough: share_secrets_threshold: " | |||
| << cipher_init_->share_secrets_threshold << "encrypted_shares_num: " << encrypted_shares_num; | |||
| return false; | |||
| @@ -116,7 +125,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, | |||
| std::string fl_id = get_secrets_req->fl_id()->str(); | |||
| // the client is not in share secrets client list. | |||
| if (encrypted_shares_all.find(fl_id) == encrypted_shares_all.end()) { | |||
| BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, nullptr); | |||
| BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iteration, next_req_time, nullptr); | |||
| MS_LOG(ERROR) << "GetSecrets: client is not in share secrets client list."; | |||
| return false; | |||
| } | |||
| @@ -125,7 +134,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, | |||
| cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetSecretsClientList, fl_id); | |||
| if (!retcode_client) { | |||
| MS_LOG(ERROR) << "update get secrets clients failed"; | |||
| BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr); | |||
| BuildGetSecretsRsp(fbb, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr); | |||
| return false; | |||
| } | |||
| @@ -158,15 +167,14 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, | |||
| std::vector<clientshare_str>::iterator ptr_start = encrypted_shares_add.begin(); | |||
| std::vector<clientshare_str>::iterator ptr_end = ptr_start + size_shares; | |||
| for (std::vector<clientshare_str>::iterator ptr = ptr_start; ptr < ptr_end; ++ptr) { | |||
| auto one_fl_id = get_secrets_resp_builder->CreateString(ptr->fl_id); | |||
| auto two_share = get_secrets_resp_builder->CreateVector(ptr->share.data(), ptr->share.size()); | |||
| auto one_fl_id = fbb->CreateString(ptr->fl_id); | |||
| auto two_share = fbb->CreateVector(ptr->share.data(), ptr->share.size()); | |||
| auto third_index = ptr->index; | |||
| auto one_clientshare = schema::CreateClientShare(*get_secrets_resp_builder, one_fl_id, two_share, third_index); | |||
| auto one_clientshare = schema::CreateClientShare(*fbb, one_fl_id, two_share, third_index); | |||
| encrypted_shares.push_back(one_clientshare); | |||
| } | |||
| BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SUCCEED, iteration, next_req_time, | |||
| &encrypted_shares); | |||
| BuildGetSecretsRsp(fbb, schema::ResponseCode_SUCCEED, iteration, next_req_time, &encrypted_shares); | |||
| MS_LOG(INFO) << "CipherShares::GetSecrets Success"; | |||
| clock_t end_time = clock(); | |||
| double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); | |||
| @@ -175,20 +183,20 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, | |||
| } | |||
| void CipherShares::BuildGetSecretsRsp( | |||
| const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, schema::ResponseCode retcode, int iteration, | |||
| std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) { | |||
| const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode, size_t iteration, | |||
| const std::string &next_req_time, | |||
| const std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) { | |||
| int rsp_retcode = retcode; | |||
| int rsp_iteration = iteration; | |||
| auto rsp_next_req_time = get_secrets_resp_builder->CreateString(next_req_time); | |||
| int rsp_iteration = SizeToInt(iteration); | |||
| auto rsp_next_req_time = fbb->CreateString(next_req_time); | |||
| if (encrypted_shares == nullptr) { | |||
| auto get_secrets_rsp = | |||
| schema::CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration, 0, rsp_next_req_time); | |||
| get_secrets_resp_builder->Finish(get_secrets_rsp); | |||
| auto get_secrets_rsp = schema::CreateReturnShareSecrets(*fbb, rsp_retcode, rsp_iteration, 0, rsp_next_req_time); | |||
| fbb->Finish(get_secrets_rsp); | |||
| } else { | |||
| auto encrypted_shares_rsp = get_secrets_resp_builder->CreateVector(*encrypted_shares); | |||
| auto get_secrets_rsp = CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration, | |||
| encrypted_shares_rsp, rsp_next_req_time); | |||
| get_secrets_resp_builder->Finish(get_secrets_rsp); | |||
| auto encrypted_shares_rsp = fbb->CreateVector(*encrypted_shares); | |||
| auto get_secrets_rsp = | |||
| CreateReturnShareSecrets(*fbb, rsp_retcode, rsp_iteration, encrypted_shares_rsp, rsp_next_req_time); | |||
| fbb->Finish(get_secrets_rsp); | |||
| } | |||
| return; | |||
| } | |||
| @@ -207,6 +215,7 @@ void CipherShares::BuildShareSecretsRsp(const std::shared_ptr<fl::server::FBBuil | |||
| void CipherShares::ClearShareSecrets() { | |||
| fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxShareSecretsClientList); | |||
| fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsEncryptedShares); | |||
| fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxGetSecretsClientList); | |||
| } | |||
| } // namespace armour | |||
| @@ -35,6 +35,7 @@ class CipherShares { | |||
| public: | |||
| // initialize: get cipher_init_ | |||
| CipherShares() { cipher_init_ = &CipherInit::GetInstance(); } | |||
| ~CipherShares() = default; | |||
| static CipherShares &GetInstance() { | |||
| static CipherShares instance; | |||
| @@ -55,9 +56,9 @@ class CipherShares { | |||
| const schema::ResponseCode retcode, const string &reason, const string &next_req_time, | |||
| const int iteration); | |||
| // build response code of get secrets. | |||
| void BuildGetSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, | |||
| const schema::ResponseCode retcode, const int iteration, std::string next_req_time, | |||
| std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares); | |||
| void BuildGetSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const size_t iteration, const string &next_req_time, | |||
| const std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares); | |||
| // clear the shared memory. | |||
| void ClearShareSecrets(); | |||
| @@ -28,17 +28,24 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) { | |||
| bool ret = cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise); | |||
| if (!ret || noise.size() != cipher_init_->featuremap_) { | |||
| MS_LOG(ERROR) << " CipherMgr UnMask ERROR"; | |||
| MS_LOG(WARNING) << "Client noises is not ready"; | |||
| return false; | |||
| } | |||
| size_t data_size = fl::server::LocalMetaStore::GetInstance().value<size_t>(fl::server::kCtxFedAvgTotalDataSize); | |||
| if (data_size == 0) { | |||
| MS_LOG(ERROR) << "FedAvgTotalDataSize equals to 0"; | |||
| return false; | |||
| } | |||
| int sum_size = 0; | |||
| for (auto iter = data.begin(); iter != data.end(); ++iter) { | |||
| int size_data = iter->second->size / sizeof(float); | |||
| if (iter->second == nullptr) { | |||
| MS_LOG(ERROR) << "AddressPtr is nullptr"; | |||
| return false; | |||
| } | |||
| size_t size_data = iter->second->size / sizeof(float); | |||
| float *in_data = reinterpret_cast<float *>(iter->second->addr); | |||
| MS_LOG(INFO) << " weight name : " << iter->first; | |||
| for (int i = 0; i < size_data; ++i) { | |||
| for (size_t i = 0; i < size_data; ++i) { | |||
| in_data[i] = in_data[i] + noise[i + sum_size] / data_size; | |||
| } | |||
| sum_size += size_data; | |||
| @@ -33,6 +33,7 @@ class CipherUnmask { | |||
| public: | |||
| // initialize: get cipher_init_ | |||
| CipherUnmask() { cipher_init_ = &CipherInit::GetInstance(); } | |||
| ~CipherUnmask() = default; | |||
| // unmask the data by secret mask. | |||
| bool UnMask(const std::map<std::string, AddressPtr> &data); | |||
| @@ -19,11 +19,11 @@ | |||
| namespace mindspore { | |||
| namespace armour { | |||
| AESEncrypt::AESEncrypt(const uint8_t *key, int key_len, const uint8_t *ivec, int ivec_len, const AES_MODE mode) { | |||
| privKey = key; | |||
| privKeyLen = key_len; | |||
| iVec = ivec; | |||
| iVecLen = ivec_len; | |||
| aesMode = mode; | |||
| priv_key_ = key; | |||
| priv_key_len_ = key_len; | |||
| ivec_ = ivec; | |||
| ivec_len_ = ivec_len; | |||
| aes_mode_ = mode; | |||
| } | |||
| AESEncrypt::~AESEncrypt() {} | |||
| @@ -42,28 +42,23 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt | |||
| #else | |||
| int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) { | |||
| int ret; | |||
| if (privKey == NULL || iVec == NULL) { | |||
| if (priv_key_ == nullptr || ivec_ == nullptr) { | |||
| MS_LOG(ERROR) << "private key or init vector is invalid."; | |||
| return -1; | |||
| } | |||
| if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) { | |||
| if (priv_key_len_ != KEY_LENGTH_16 && priv_key_len_ != KEY_LENGTH_32) { | |||
| MS_LOG(ERROR) << "key length is invalid."; | |||
| return -1; | |||
| } | |||
| if (iVecLen != AES_IV_SIZE) { | |||
| if (ivec_len_ != AES_IV_SIZE) { | |||
| MS_LOG(ERROR) << "initial vector size is invalid."; | |||
| return -1; | |||
| } | |||
| if (data == NULL || len <= 0 || encrypt_data == NULL) { | |||
| if (data == nullptr || len <= 0 || encrypt_data == nullptr || encrypt_len == nullptr) { | |||
| MS_LOG(ERROR) << "input data is invalid."; | |||
| return -1; | |||
| } | |||
| if (aesMode == AES_CBC || aesMode == AES_CTR) { | |||
| ret = evp_aes_encrypt(data, len, privKey, iVec, encrypt_data, encrypt_len); | |||
| } else { | |||
| MS_LOG(ERROR) << "Please use CBC mode or CTR mode, the other modes are not supported!"; | |||
| ret = -1; | |||
| } | |||
| ret = evp_aes_encrypt(data, len, priv_key_, ivec_, encrypt_data, encrypt_len); | |||
| if (ret != 0) { | |||
| return -1; | |||
| } | |||
| @@ -72,26 +67,26 @@ int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned c | |||
| int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) { | |||
| int ret = 0; | |||
| if (privKey == NULL || iVec == NULL) { | |||
| if (priv_key_ == nullptr || ivec_ == nullptr) { | |||
| MS_LOG(ERROR) << "private key or init vector is invalid."; | |||
| return -1; | |||
| } | |||
| if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) { | |||
| if (priv_key_len_ != KEY_LENGTH_16 && priv_key_len_ != KEY_LENGTH_32) { | |||
| MS_LOG(ERROR) << "key length is invalid."; | |||
| return -1; | |||
| } | |||
| if (iVecLen != AES_IV_SIZE) { | |||
| if (ivec_len_ != AES_IV_SIZE) { | |||
| MS_LOG(ERROR) << "initial vector size is invalid."; | |||
| return -1; | |||
| } | |||
| if (data == NULL || encrypt_len <= 0 || encrypt_data == NULL) { | |||
| if (data == nullptr || encrypt_len <= 0 || encrypt_data == nullptr || len == nullptr) { | |||
| MS_LOG(ERROR) << "input data is invalid."; | |||
| return -1; | |||
| } | |||
| if (aesMode == AES_CBC || aesMode == AES_CTR) { | |||
| ret = evp_aes_decrypt(encrypt_data, encrypt_len, privKey, iVec, data, len); | |||
| if (aes_mode_ == AES_CBC || aes_mode_ == AES_CTR) { | |||
| ret = evp_aes_decrypt(encrypt_data, encrypt_len, priv_key_, ivec_, data, len); | |||
| } else { | |||
| MS_LOG(ERROR) << "Please use CBC mode or CTR mode, the other modes are not supported!"; | |||
| MS_LOG(ERROR) << "This encryption mode is not supported!"; | |||
| } | |||
| if (ret != 1) { | |||
| return -1; | |||
| @@ -99,17 +94,16 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt | |||
| return 0; | |||
| } | |||
| int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec, | |||
| uint8_t *encrypt_data, int *encrypt_len) { | |||
| const int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec, | |||
| uint8_t *encrypt_data, int *encrypt_len) { | |||
| EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); | |||
| if (ctx == NULL) { | |||
| MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!"; | |||
| return -1; | |||
| } | |||
| int out_len; | |||
| int ret; | |||
| if (aesMode == AES_CBC) { | |||
| switch (privKeyLen) { | |||
| if (aes_mode_ == AES_CBC) { | |||
| switch (priv_key_len_) { | |||
| case KEY_LENGTH_16: | |||
| ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec); | |||
| break; | |||
| @@ -121,13 +115,12 @@ int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_ | |||
| ret = -1; | |||
| } | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_EncryptInit_ex CBC fail!"; | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| EVP_CIPHER_CTX_set_padding(ctx, EVP_PADDING_PKCS7); | |||
| } else if (aesMode == AES_CTR) { | |||
| switch (privKeyLen) { | |||
| } else if (aes_mode_ == AES_CTR) { | |||
| switch (priv_key_len_) { | |||
| case KEY_LENGTH_16: | |||
| ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec); | |||
| break; | |||
| @@ -139,25 +132,22 @@ int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_ | |||
| ret = -1; | |||
| } | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_EncryptInit_ex CTR fail!"; | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported AES mode"; | |||
| MS_LOG(ERROR) << "Unsupported encryption mode"; | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| ret = EVP_EncryptUpdate(ctx, encrypt_data, &out_len, data, len); | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_EncryptUpdate fail!"; | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| *encrypt_len = out_len; | |||
| ret = EVP_EncryptFinal_ex(ctx, encrypt_data + *encrypt_len, &out_len); | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_EncryptFinal_ex fail!"; | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| @@ -166,17 +156,16 @@ int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_ | |||
| return 0; | |||
| } | |||
| int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec, | |||
| uint8_t *decrypt_data, int *decrypt_len) { | |||
| const int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, | |||
| const uint8_t *ivec, uint8_t *decrypt_data, int *decrypt_len) { | |||
| EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); | |||
| if (ctx == NULL) { | |||
| MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!"; | |||
| return -1; | |||
| } | |||
| int out_len; | |||
| int ret; | |||
| if (aesMode == AES_CBC) { | |||
| switch (privKeyLen) { | |||
| if (aes_mode_ == AES_CBC) { | |||
| switch (priv_key_len_) { | |||
| case KEY_LENGTH_16: | |||
| ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec); | |||
| break; | |||
| @@ -191,8 +180,8 @@ int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, cons | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| } else if (aesMode == AES_CTR) { | |||
| switch (privKeyLen) { | |||
| } else if (aes_mode_ == AES_CTR) { | |||
| switch (priv_key_len_) { | |||
| case KEY_LENGTH_16: | |||
| ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec); | |||
| break; | |||
| @@ -208,21 +197,19 @@ int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, cons | |||
| return -1; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported AES mode"; | |||
| MS_LOG(ERROR) << "Unsupported encryption mode"; | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| ret = EVP_DecryptUpdate(ctx, decrypt_data, &out_len, encrypt_data, len); | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_DecryptUpdate fail!"; | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| *decrypt_len = out_len; | |||
| ret = EVP_DecryptFinal_ex(ctx, decrypt_data + *decrypt_len, &out_len); | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_DecryptFinal_ex fail!"; | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| @@ -43,15 +43,15 @@ class AESEncrypt : SymmetricEncrypt { | |||
| int DecryptData(const uint8_t *encrypt_data, const int encrypt_len, uint8_t *data, int *len); | |||
| private: | |||
| const uint8_t *privKey; | |||
| int privKeyLen; | |||
| const uint8_t *iVec; | |||
| int iVecLen; | |||
| AES_MODE aesMode; | |||
| int evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec, | |||
| uint8_t *encrypt_data, int *encrypt_len); | |||
| int evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec, | |||
| uint8_t *decrypt_data, int *decrypt_len); | |||
| const uint8_t *priv_key_; | |||
| int priv_key_len_; | |||
| const uint8_t *ivec_; | |||
| int ivec_len_; | |||
| AES_MODE aes_mode_; | |||
| const int evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec, | |||
| uint8_t *encrypt_data, int *encrypt_len); | |||
| const int evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec, | |||
| uint8_t *decrypt_data, int *decrypt_len); | |||
| }; | |||
| } // namespace armour | |||
| @@ -54,9 +54,9 @@ PrivateKey::PrivateKey(EVP_PKEY *evpKey) { evpPrivKey = evpKey; } | |||
| PrivateKey::~PrivateKey() { EVP_PKEY_free(evpPrivKey); } | |||
| int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) { | |||
| if (privKeyBytes == nullptr || len <= 0) { | |||
| MS_LOG(ERROR) << "input privKeyBytes invalid."; | |||
| int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) const { | |||
| if (privKeyBytes == nullptr || len == nullptr || evpPrivKey == nullptr) { | |||
| MS_LOG(ERROR) << "input data invalid."; | |||
| return -1; | |||
| } | |||
| if (!EVP_PKEY_get_raw_private_key(evpPrivKey, privKeyBytes, len)) { | |||
| @@ -65,8 +65,8 @@ int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) { | |||
| return 0; | |||
| } | |||
| int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) { | |||
| if (pubKeyBytes == nullptr || len <= 0) { | |||
| int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) const { | |||
| if (pubKeyBytes == nullptr || len == nullptr || evpPrivKey == nullptr) { | |||
| MS_LOG(ERROR) << "input pubKeyBytes invalid."; | |||
| return -1; | |||
| } | |||
| @@ -90,11 +90,10 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c | |||
| MS_LOG(ERROR) << "input salt in invalid."; | |||
| return -1; | |||
| } | |||
| EVP_PKEY_CTX *ctx; | |||
| size_t len = 0; | |||
| ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL); | |||
| if (!ctx) { | |||
| MS_LOG(ERROR) << "EVP_PKEY_CTX_new failed!"; | |||
| EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL); | |||
| if (ctx == nullptr) { | |||
| MS_LOG(ERROR) << "new context failed!"; | |||
| return -1; | |||
| } | |||
| if (EVP_PKEY_derive_init(ctx) <= 0) { | |||
| @@ -107,15 +106,17 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c | |||
| EVP_PKEY_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| unsigned char *secret; | |||
| if (EVP_PKEY_derive(ctx, NULL, &len) <= 0) { | |||
| MS_LOG(ERROR) << "get derive key size failed!"; | |||
| EVP_PKEY_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| secret = (unsigned char *)OPENSSL_malloc(len); | |||
| if (!secret) { | |||
| if (len == 0) { | |||
| EVP_PKEY_CTX_free(ctx); | |||
| return -1; | |||
| } | |||
| uint8_t *secret = reinterpret_cast<uint8_t *>(OPENSSL_malloc(len)); | |||
| if (secret == nullptr) { | |||
| MS_LOG(ERROR) << "malloc secret memory failed!"; | |||
| EVP_PKEY_CTX_free(ctx); | |||
| return -1; | |||
| @@ -142,7 +143,7 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c | |||
| PrivateKey *KeyAgreement::GeneratePrivKey() { | |||
| EVP_PKEY *evpKey = NULL; | |||
| EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_X25519, NULL); | |||
| if (!pctx) { | |||
| if (pctx == nullptr) { | |||
| return NULL; | |||
| } | |||
| if (EVP_PKEY_keygen_init(pctx) <= 0) { | |||
| @@ -168,7 +169,7 @@ PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) { | |||
| return NULL; | |||
| } | |||
| pubKeyBytes = reinterpret_cast<uint8_t *>(OPENSSL_malloc(len)); | |||
| if (!pubKeyBytes) { | |||
| if (pubKeyBytes == nullptr) { | |||
| MS_LOG(ERROR) << "malloc secret memory failed!"; | |||
| return NULL; | |||
| } | |||
| @@ -190,7 +191,11 @@ PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) { | |||
| return pubKey; | |||
| } | |||
| PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) { | |||
| PrivateKey *KeyAgreement::FromPrivateBytes(const uint8_t *data, size_t len) { | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "input data is null!"; | |||
| return NULL; | |||
| } | |||
| EVP_PKEY *evp_Key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, NULL, data, len); | |||
| if (evp_Key == NULL) { | |||
| MS_LOG(ERROR) << "create evp_Key from raw bytes failed!"; | |||
| @@ -200,7 +205,11 @@ PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) { | |||
| return privKey; | |||
| } | |||
| PublicKey *KeyAgreement::FromPublicBytes(unsigned char *data, int len) { | |||
| PublicKey *KeyAgreement::FromPublicBytes(const uint8_t *data, size_t len) { | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "input data is null!"; | |||
| return NULL; | |||
| } | |||
| EVP_PKEY *evp_pubKey = EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, data, len); | |||
| if (evp_pubKey == NULL) { | |||
| MS_LOG(ERROR) << "create evp_pubKey from raw bytes fail"; | |||
| @@ -48,8 +48,8 @@ class PrivateKey { | |||
| ~PrivateKey(); | |||
| int Exchange(PublicKey *peerPublicKey, int key_len, const unsigned char *salt, int salt_len, | |||
| unsigned char *exchangeKey); | |||
| int GetPrivateBytes(size_t *len, unsigned char *priKeyBytes); | |||
| int GetPublicBytes(size_t *len, unsigned char *pubKeyBytes); | |||
| int GetPrivateBytes(size_t *len, unsigned char *priKeyBytes) const; | |||
| int GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) const; | |||
| EVP_PKEY *evpPrivKey; | |||
| }; | |||
| #endif | |||
| @@ -58,8 +58,8 @@ class KeyAgreement { | |||
| public: | |||
| static PrivateKey *GeneratePrivKey(); | |||
| static PublicKey *GeneratePubKey(PrivateKey *privKey); | |||
| static PrivateKey *FromPrivateBytes(unsigned char *data, int len); | |||
| static PublicKey *FromPublicBytes(unsigned char *data, int len); | |||
| static PrivateKey *FromPrivateBytes(const unsigned char *data, size_t len); | |||
| static PublicKey *FromPublicBytes(const unsigned char *data, size_t len); | |||
| static int ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey, int key_len, const unsigned char *salt, | |||
| int salt_len, unsigned char *exchangeKey); | |||
| }; | |||
| @@ -18,9 +18,9 @@ | |||
| namespace mindspore { | |||
| namespace armour { | |||
| void secure_zero(unsigned char *s, size_t n) { | |||
| volatile unsigned char *p = s; | |||
| if (p) | |||
| void secure_zero(uint8_t *s, size_t n) { | |||
| volatile uint8_t *p = s; | |||
| if (p != nullptr) | |||
| while (n--) *p++ = '\0'; | |||
| } | |||
| @@ -215,17 +215,24 @@ int SecretSharing::Combine(size_t k, const std::vector<Share *> &shares, uint8_t | |||
| } | |||
| } | |||
| if (ret == -1) { | |||
| BN_clear_free(tmp); | |||
| break; | |||
| } | |||
| (void)BN_mod_inverse(tmp, denses[j], this->bn_prim_, ctx); | |||
| if (!field_mult(tmp, tmp, nums[j], ctx)) { | |||
| ret = -1; | |||
| BN_clear_free(tmp); | |||
| break; | |||
| } | |||
| if (!field_mult(tmp, tmp, y[j], ctx)) { | |||
| ret = -1; | |||
| BN_clear_free(tmp); | |||
| break; | |||
| } | |||
| if (!field_add(sum, sum, tmp, ctx)) { | |||
| ret = -1; | |||
| BN_clear_free(tmp); | |||
| break; | |||
| } | |||
| BN_clear_free(tmp); | |||
| @@ -206,7 +206,6 @@ constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration"; | |||
| constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp"; | |||
| constexpr auto kCtxUpdateModelClientList = "update_model_client_list"; | |||
| constexpr auto kCtxUpdateModelThld = "update_model_threshold"; | |||
| constexpr auto kCtxUpdateModelClientNum = "update_model_client_num"; | |||
| constexpr auto kCtxClientsKeys = "clients_keys"; | |||
| constexpr auto kCtxClientNoises = "clients_noises"; | |||
| constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares"; | |||
| @@ -54,7 +54,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient | |||
| uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld); | |||
| PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList); | |||
| const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list(); | |||
| for (int i = 0; i < client_list_pb.fl_id_size(); ++i) { | |||
| for (size_t i = 0; i < IntToSize(client_list_pb.fl_id_size()); ++i) { | |||
| client_list.push_back(client_list_pb.fl_id(i)); | |||
| } | |||
| if (static_cast<uint64_t>(client_list.size()) < update_model_client_needed) { | |||
| @@ -87,7 +87,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient | |||
| BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list, | |||
| std::to_string(CURRENT_TIME_MILLI.count()), iter_num); | |||
| MS_LOG(ERROR) << reason; | |||
| return true; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "send clients_list succeed!"; | |||
| MS_LOG(INFO) << "UpdateModel client list: "; | |||
| @@ -100,7 +100,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient | |||
| return true; | |||
| } | |||
| bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); | |||
| @@ -148,6 +148,7 @@ bool ClientListKernel::Reset() { | |||
| MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); | |||
| MS_LOG(INFO) << "Get Client list kernel reset!"; | |||
| DistributedCountService::GetInstance().ResetCounter(name_); | |||
| DistributedMetadataStore::GetInstance().ResetMetadata(kCtxGetUpdateModelClientList); | |||
| StopTimer(); | |||
| return true; | |||
| } | |||
| @@ -170,7 +171,7 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli | |||
| rsp_builder.add_retcode(retcode); | |||
| rsp_builder.add_reason(rsp_reason); | |||
| rsp_builder.add_clients(clients_fb); | |||
| rsp_builder.add_iteration(iteration); | |||
| rsp_builder.add_iteration(SizeToInt(iteration)); | |||
| rsp_builder.add_next_req_time(rsp_next_req_time); | |||
| auto rsp_exchange_keys = rsp_builder.Finish(); | |||
| client_list_resp_builder->Finish(rsp_exchange_keys); | |||
| @@ -34,7 +34,7 @@ class ClientListKernel : public RoundKernel { | |||
| ClientListKernel() = default; | |||
| ~ClientListKernel() override = default; | |||
| void InitKernel(size_t required_cnt) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| bool Reset() override; | |||
| void BuildClientListRsp(std::shared_ptr<server::FBBuilder> client_list_resp_builder, | |||
| @@ -65,7 +65,7 @@ bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> & | |||
| return true; | |||
| } | |||
| bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| MS_LOG(INFO) << "Launching ExchangeKey kernel."; | |||
| bool response = false; | |||
| @@ -35,7 +35,7 @@ class ExchangeKeysKernel : public RoundKernel { | |||
| ExchangeKeysKernel() = default; | |||
| ~ExchangeKeysKernel() override = default; | |||
| void InitKernel(size_t required_cnt) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| bool Reset() override; | |||
| @@ -51,7 +51,7 @@ bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const | |||
| return true; | |||
| } | |||
| bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| MS_LOG(INFO) << "Launching GetKeys kernel."; | |||
| bool response = false; | |||
| @@ -35,7 +35,7 @@ class GetKeysKernel : public RoundKernel { | |||
| GetKeysKernel() = default; | |||
| ~GetKeysKernel() override = default; | |||
| void InitKernel(size_t required_cnt) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| bool Reset() override; | |||
| @@ -18,6 +18,8 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <map> | |||
| #include <utility> | |||
| #include "fl/armour/cipher/cipher_shares.h" | |||
| namespace mindspore { | |||
| @@ -52,7 +54,7 @@ bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, | |||
| return true; | |||
| } | |||
| bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| bool response = false; | |||
| size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| @@ -34,7 +34,7 @@ class GetSecretsKernel : public RoundKernel { | |||
| GetSecretsKernel() = default; | |||
| ~GetSecretsKernel() override = default; | |||
| void InitKernel(size_t required_cnt) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| bool Reset() override; | |||
| @@ -50,7 +50,7 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) { | |||
| {first_cnt_handler, last_cnt_handler}); | |||
| } | |||
| bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| bool response = false; | |||
| size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| @@ -80,7 +80,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con | |||
| DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList); | |||
| const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list(); | |||
| for (int i = 0; i < update_model_clients_pb.fl_id_size(); ++i) { | |||
| for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) { | |||
| update_model_clients.push_back(update_model_clients_pb.fl_id(i)); | |||
| } | |||
| @@ -36,7 +36,7 @@ class ReconstructSecretsKernel : public RoundKernel { | |||
| ~ReconstructSecretsKernel() override = default; | |||
| void InitKernel(size_t required_cnt) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| bool Reset() override; | |||
| void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override; | |||
| @@ -44,7 +44,7 @@ class ReconstructSecretsKernel : public RoundKernel { | |||
| private: | |||
| std::string name_unmask_; | |||
| Executor *executor_; | |||
| size_t iteration_time_window_; | |||
| size_t iteration_time_window_{0}; | |||
| armour::CipherReconStruct cipher_reconstruct_; | |||
| }; | |||
| } // namespace kernel | |||
| @@ -38,8 +38,7 @@ void ShareSecretsKernel::InitKernel(size_t) { | |||
| bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestShareSecrets *share_secrets_req, | |||
| const int iter_num) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(share_secrets_req, false); | |||
| const size_t iter_num) { | |||
| if (!DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str())) { | |||
| std::string reason = "Counting for share secret kernel request failed. Please retry later."; | |||
| cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason, | |||
| @@ -50,7 +49,7 @@ bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> & | |||
| return true; | |||
| } | |||
| bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| bool response = false; | |||
| size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| @@ -35,7 +35,7 @@ class ShareSecretsKernel : public RoundKernel { | |||
| ShareSecretsKernel() = default; | |||
| ~ShareSecretsKernel() override = default; | |||
| void InitKernel(size_t required_cnt) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| bool Reset() override; | |||
| @@ -44,7 +44,7 @@ class ShareSecretsKernel : public RoundKernel { | |||
| size_t iteration_time_window_; | |||
| armour::CipherShares *cipher_share_; | |||
| bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req, | |||
| const int iter_num); | |||
| const size_t iter_num); | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace server | |||
| @@ -118,7 +118,6 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std:: | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| } | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| @@ -245,38 +245,18 @@ void Server::InitIteration() { | |||
| void Server::InitCipher() { | |||
| #ifdef ENABLE_ARMOUR | |||
| cipher_init_ = &armour::CipherInit::GetInstance(); | |||
| int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_); | |||
| unsigned char cipher_p[SECRET_MAX_LEN] = {0}; | |||
| const int cipher_g = 1; | |||
| unsigned char cipher_prime[PRIME_MAX_LEN] = {0}; | |||
| float dp_eps = ps::PSContext::instance()->dp_eps(); | |||
| float dp_delta = ps::PSContext::instance()->dp_delta(); | |||
| float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip(); | |||
| std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); | |||
| BIGNUM *prim = BN_new(); | |||
| if (prim == nullptr) { | |||
| MS_LOG(EXCEPTION) << "new bn failed"; | |||
| } | |||
| mindspore::armour::GetPrime(prim); | |||
| MS_LOG(INFO) << "prime" << BN_bn2hex(prim); | |||
| (void)BN_bn2bin(prim, reinterpret_cast<uint8_t *>(cipher_prime)); | |||
| if (prim != nullptr) { | |||
| BN_clear_free(prim); | |||
| } | |||
| mindspore::armour::CipherPublicPara param; | |||
| param.g = cipher_g; | |||
| param.t = cipher_t; | |||
| int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return; | |||
| } | |||
| ret = memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN); | |||
| int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p)); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return; | |||
| @@ -285,6 +265,23 @@ void Server::InitCipher() { | |||
| param.dp_eps = dp_eps; | |||
| param.dp_norm_clip = dp_norm_clip; | |||
| param.encrypt_type = encrypt_type; | |||
| BIGNUM *prim = BN_new(); | |||
| if (prim == NULL) { | |||
| MS_LOG(EXCEPTION) << "new bn failed."; | |||
| ret = -1; | |||
| } else { | |||
| ret = mindspore::armour::GetPrime(prim); | |||
| } | |||
| if (ret == 0) { | |||
| (void)BN_bn2bin(prim, reinterpret_cast<uint8_t *>(param.prime)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Get prime failed."; | |||
| } | |||
| if (prim != NULL) { | |||
| BN_clear_free(prim); | |||
| } | |||
| cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_, | |||
| cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_, | |||
| cipher_reconstruct_secrets_up_cnt_); | |||