Browse Source

Fix code-warnings of federated's secure aggregation

tags/v1.5.0-rc1
jin-xiulang 4 years ago
parent
commit
6a1aa8e4c5
31 changed files with 342 additions and 236 deletions
  1. +5
    -1
      mindspore/ccsrc/fl/armour/cipher/cipher_init.cc
  2. +13
    -10
      mindspore/ccsrc/fl/armour/cipher/cipher_init.h
  3. +29
    -19
      mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc
  4. +7
    -6
      mindspore/ccsrc/fl/armour/cipher/cipher_keys.h
  5. +49
    -16
      mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc
  6. +63
    -30
      mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc
  7. +6
    -2
      mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.h
  8. +34
    -25
      mindspore/ccsrc/fl/armour/cipher/cipher_shares.cc
  9. +4
    -3
      mindspore/ccsrc/fl/armour/cipher/cipher_shares.h
  10. +11
    -4
      mindspore/ccsrc/fl/armour/cipher/cipher_unmask.cc
  11. +1
    -0
      mindspore/ccsrc/fl/armour/cipher/cipher_unmask.h
  12. +31
    -44
      mindspore/ccsrc/fl/armour/secure_protocol/encrypt.cc
  13. +9
    -9
      mindspore/ccsrc/fl/armour/secure_protocol/encrypt.h
  14. +26
    -17
      mindspore/ccsrc/fl/armour/secure_protocol/key_agreement.cc
  15. +4
    -4
      mindspore/ccsrc/fl/armour/secure_protocol/key_agreement.h
  16. +10
    -3
      mindspore/ccsrc/fl/armour/secure_protocol/secret_sharing.cc
  17. +0
    -1
      mindspore/ccsrc/fl/server/common.h
  18. +5
    -4
      mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc
  19. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.h
  20. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc
  21. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.h
  22. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc
  23. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.h
  24. +3
    -1
      mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc
  25. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.h
  26. +2
    -2
      mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc
  27. +2
    -2
      mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.h
  28. +2
    -3
      mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc
  29. +2
    -2
      mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.h
  30. +0
    -1
      mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc
  31. +18
    -21
      mindspore/ccsrc/fl/server/server.cc

+ 5
- 1
mindspore/ccsrc/fl/armour/cipher/cipher_init.cc View File

@@ -27,7 +27,11 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
size_t cipher_reconstruct_secrets_up_cnt) {
MS_LOG(INFO) << "CipherInit::Init START";
if (memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN) != 0) {
if (publicparam_.p == nullptr || param.p == nullptr || param.prime == nullptr || publicparam_.prime == nullptr) {
MS_LOG(ERROR) << "CipherInit::input data invalid.";
return false;
}
if (memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, sizeof(param.p)) != 0) {
MS_LOG(ERROR) << "CipherInit::memory copy failed.";
return false;
}


+ 13
- 10
mindspore/ccsrc/fl/armour/cipher/cipher_init.h View File

@@ -43,25 +43,28 @@ class CipherInit {
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
size_t cipher_reconstruct_secrets_up_cnt);

// Check whether the parameters are valid.
bool Check_Parames();

// Get public params. which is given to start fl job thread.
CipherPublicPara *GetPublicParams() { return &publicparam_; }

size_t share_secrets_threshold; // the minimum number of clients to share secret fragments.
size_t get_secrets_threshold; // the minimum number of clients to get secret fragments.
size_t reconstruct_secrets_threshold; // the minimum number of clients to reconstruct secret mask.
size_t exchange_key_threshold; // the minimum number of clients to send public keys.
size_t get_key_threshold; // the minimum number of clients to get public keys.
size_t client_list_threshold; // the minimum number of clients to get update model client list.

size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask.
size_t featuremap_; // the size of data to deal.
size_t time_out_mutex_; // timeout mutex.
size_t push_list_sign_threshold; // the minimum number of clients to push client list signature.
size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask.
size_t featuremap_; // the size of data to deal.

CipherPublicPara publicparam_; // the param containing encrypted public parameters.
CipherMetaStorage cipher_meta_storage_;

private:
size_t client_list_threshold; // the minimum number of clients to get update model client list.
size_t get_key_threshold; // the minimum number of clients to get public keys.
size_t get_list_sign_threshold; // the minimum number of clients to get client list signature.
size_t get_secrets_threshold; // the minimum number of clients to get secret fragments.
size_t time_out_mutex_; // timeout mutex.

// Check whether the parameters are valid.
bool Check_Parames();
};
} // namespace armour
} // namespace mindspore


+ 29
- 19
mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc View File

@@ -19,12 +19,16 @@

namespace mindspore {
namespace armour {
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
bool CipherKeys::GetKeys(const size_t cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
MS_LOG(INFO) << "CipherMgr::GetKeys START";
if (get_exchange_keys_req == nullptr) {
MS_LOG(ERROR) << "Request is nullptr";
BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
return false;
}
if (cipher_init_ == nullptr) {
BuildGetKeysRsp(fbb, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false);
return false;
}
@@ -61,7 +65,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
return true;
}

bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
@@ -72,6 +76,11 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason, next_req_time, cur_iterator);
return false;
}
if (cipher_init_ == nullptr) {
std::string reason = "cipher_init_ is nullptr";
BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, reason, next_req_time, cur_iterator);
return false;
}
std::string fl_id = exchange_keys_req->fl_id()->str();
mindspore::fl::PBMetadata device_metas =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxDeviceMetas);
@@ -127,9 +136,9 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
}
}

void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time,
const int iteration) {
void CipherKeys::BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb,
const schema::ResponseCode retcode, const std::string &reason,
const std::string &next_req_time, const size_t iteration) {
auto rsp_reason = fbb->CreateString(reason);
auto rsp_next_req_time = fbb->CreateString(next_req_time);

@@ -137,21 +146,21 @@ void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_next_req_time(rsp_next_req_time);
rsp_builder.add_iteration(iteration);
rsp_builder.add_iteration(SizeToInt(iteration));
auto rsp_exchange_keys = rsp_builder.Finish();
fbb->Finish(rsp_exchange_keys);
return;
}

void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good) {
void CipherKeys::BuildGetKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const size_t iteration, const std::string &next_req_time, bool is_good) {
if (!is_good) {
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
schema::ReturnExchangeKeysBuilder rsp_builder(*(fbb.get()));
rsp_builder.add_retcode(static_cast<int>(retcode));
rsp_builder.add_iteration(SizeToInt(iteration));
rsp_builder.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_builder.Finish();
fbb->Finish(rsp_get_keys);
return;
}
@@ -176,12 +185,12 @@ void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, con
}
auto remote_publickeys = fbb->CreateVector(public_keys_list);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_remote_publickeys(remote_publickeys);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
schema::ReturnExchangeKeysBuilder rsp_builder(*(fbb.get()));
rsp_builder.add_retcode(static_cast<int>(retcode));
rsp_builder.add_iteration(SizeToInt(iteration));
rsp_builder.add_remote_publickeys(remote_publickeys);
rsp_builder.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_builder.Finish();
fbb->Finish(rsp_get_keys);
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
return;
@@ -190,6 +199,7 @@ void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, con
void CipherKeys::ClearKeys() {
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxExChangeKeysClientList);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsKeys);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxGetKeysClientList);
}

} // namespace armour


+ 7
- 6
mindspore/ccsrc/fl/armour/cipher/cipher_keys.h View File

@@ -36,6 +36,7 @@ class CipherKeys {
public:
// initialize: get cipher_init_
CipherKeys() { cipher_init_ = &CipherInit::GetInstance(); }
~CipherKeys() = default;

static CipherKeys &GetInstance() {
static CipherKeys instance;
@@ -43,20 +44,20 @@ class CipherKeys {
}

// handle the client's request of get keys.
bool GetKeys(const int cur_iterator, const std::string &next_req_time,
bool GetKeys(const size_t cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req, const std::shared_ptr<fl::server::FBBuilder> &fbb);

// handle the client's request of exchange keys.
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
bool ExchangeKeys(const size_t cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb);

// build response code of get keys.
void BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good);
void BuildGetKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const size_t iteration, const std::string &next_req_time, bool is_good);
// build response code of exchange keys.
void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time, const int iteration);
void BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time, const size_t iteration);
// clear the shared memory.
void ClearKeys();



+ 49
- 16
mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc View File

@@ -21,6 +21,10 @@ namespace armour {

void CipherMetaStorage::GetClientSharesFromServer(
const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) {
if (clients_shares_list == nullptr) {
MS_LOG(ERROR) << "input clients_shares_list is nullptr";
return;
}
const fl::PBMetadata &clients_shares_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
@@ -29,7 +33,8 @@ void CipherMetaStorage::GetClientSharesFromServer(
std::string fl_id = iter->first;
const fl::SharesPb &shares_pb = iter->second;
std::vector<clientshare_str> encrpted_shares_new;
for (int index_shares = 0; index_shares < shares_pb.clientsharestrs_size(); ++index_shares) {
size_t client_share_num = IntToSize(shares_pb.clientsharestrs_size());
for (size_t index_shares = 0; index_shares < client_share_num; ++index_shares) {
const fl::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
clientshare_str new_clientshare;
new_clientshare.fl_id = client_share_str_pb.fl_id();
@@ -42,9 +47,14 @@ void CipherMetaStorage::GetClientSharesFromServer(
}

void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) {
if (clients_list == nullptr) {
MS_LOG(ERROR) << "input clients_list is nullptr";
return;
}
const fl::PBMetadata &client_list_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
size_t client_list_num = IntToSize(client_list_pb.fl_id_size());
for (size_t i = 0; i < client_list_num; ++i) {
std::string fl_id = client_list_pb.fl_id(i);
clients_list->push_back(fl_id);
}
@@ -52,6 +62,10 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect

void CipherMetaStorage::GetClientKeysFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) {
if (clients_keys_list == nullptr) {
MS_LOG(ERROR) << "input clients_keys_list is nullptr";
return;
}
const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
@@ -65,12 +79,16 @@ void CipherMetaStorage::GetClientKeysFromServer(
std::vector<std::vector<uint8_t>> cur_keys;
cur_keys.push_back(cpk);
cur_keys.push_back(spk);
clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
(void)clients_keys_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
}
}

void CipherMetaStorage::GetClientIVsFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) {
if (clients_ivs_list == nullptr) {
MS_LOG(ERROR) << "input clients_ivs_list is nullptr";
return;
}
const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
@@ -86,25 +104,28 @@ void CipherMetaStorage::GetClientIVsFromServer(
cur_ivs.push_back(ind_iv);
cur_ivs.push_back(pw_iv);
cur_ivs.push_back(pw_salt);
clients_ivs_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs));
(void)clients_ivs_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs));
}
}

bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) {
if (cur_public_noise == nullptr) {
MS_LOG(ERROR) << "input cur_public_noise is nullptr";
return false;
}
const fl::PBMetadata &clients_noises_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
int count = 0;
int count_thld = 100;
const int count_thld = 1000;
while (clients_noises_pb.has_one_client_noises() == false) {
int register_time = 500;
const int register_time = 500;
std::this_thread::sleep_for(std::chrono::milliseconds(register_time));
count++;
if (count >= count_thld) break;
}
MS_LOG(INFO) << "GetClientNoisesFromServer Count: " << count;
if (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(ERROR) << "GetClientNoisesFromServer NULL.";
MS_LOG(WARNING) << "GetClientNoisesFromServer Count: " << count;
return false;
}
cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(),
@@ -112,17 +133,22 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
return true;
}

bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) {
bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, uint8_t *prime) {
if (prime == nullptr) {
MS_LOG(ERROR) << "input prime is nullptr";
return false;
}
const fl::PBMetadata &prime_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
fl::Prime prime_pb(prime_pb_out.prime());
std::string str = *(prime_pb.mutable_prime());
MS_LOG(INFO) << "get prime from metastorage :" << str;

if (str.size() != PRIME_MAX_LEN) {
MS_LOG(ERROR) << "get prime size is :" << str.size();
return false;
} else {
memcpy_s(prime, PRIME_MAX_LEN, str.data(), PRIME_MAX_LEN);
if (memcpy_s(prime, PRIME_MAX_LEN, str.data(), str.size()) != 0) {
MS_LOG(ERROR) << "Memcpy_s error";
return false;
}
return true;
}
}
@@ -143,12 +169,13 @@ void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &
fl::PBMetadata prime_pb;
prime_pb.mutable_prime()->MergeFrom(prime_id_pb);
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
sleep(1);
uint32_t time = 1;
(void)sleep(time);
}

bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
const std::vector<std::vector<uint8_t>> &cur_public_key) {
size_t correct_size = 2;
const size_t correct_size = 2;
if (cur_public_key.size() < correct_size) {
MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
return false;
@@ -245,14 +272,18 @@ bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const s
bool CipherMetaStorage::UpdateClientShareToServer(
const char *list_name, const std::string &fl_id,
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
int size_shares = shares->size();
if (shares == nullptr) {
return false;
}
size_t size_shares = shares->size();
fl::SharesPb shares_pb;
for (int index = 0; index < size_shares; ++index) {
for (size_t index = 0; index < size_shares; ++index) {
// new item
fl::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
std::string fl_id_new = (*shares)[index]->fl_id()->str();
int index_new = (*shares)[index]->index();
auto share = (*shares)[index]->share();
if (share == nullptr) return false;
client_share_str_new_p->set_share(reinterpret_cast<const char *>(share->data()), share->size());
client_share_str_new_p->set_fl_id(fl_id_new);
client_share_str_new_p->set_index(index_new);
@@ -293,6 +324,8 @@ void CipherMetaStorage::RegisterClass() {
fl::PBMetadata get_update_clients_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetUpdateModelClientList,
get_update_clients_list);
fl::PBMetadata client_noises;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientNoises, client_noises);
}
} // namespace armour
} // namespace mindspore

+ 63
- 30
mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc View File

@@ -58,8 +58,8 @@ bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp,
for (int i = 0; i < static_cast<int>(cipher_init_->secrets_minnums_); ++i) {
shares_tmp->at(i)->index = (iter->second)[i].index;
shares_tmp->at(i)->len = (iter->second)[i].share.size();
if (memcpy_s(shares_tmp->at(i)->data, shares_tmp->at(i)->len, (iter->second)[i].share.data(),
shares_tmp->at(i)->len) != 0) {
if (memcpy_s(shares_tmp->at(i)->data, SHARE_MAX_SIZE, (iter->second)[i].share.data(), shares_tmp->at(i)->len) !=
0) {
MS_LOG(ERROR) << "shares_tmp copy failed";
retcode = false;
}
@@ -78,11 +78,15 @@ bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp,
// reconstruct pairwise noise
MS_LOG(INFO) << "start reconstruct pairwise noise.";
std::vector<float> noise(cipher_init_->featuremap_, 0.0);
if (GetSuvNoise(clients_share_list, record_public_keys, client_ivs, fl_id, &noise, secret, length) == false)
retcode = false;
client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
MS_LOG(INFO) << " fl_id : " << fl_id;
MS_LOG(INFO) << "end get complete s_uv.";
if (GetSuvNoise(clients_share_list, record_public_keys, client_ivs, fl_id, &noise, secret, length) == false) {
MS_LOG(ERROR) << "GetSuvNoise failed";
BN_clear_free(prime);
if (memset_s(secret, SECRET_MAX_LEN, 0, length) != 0) {
MS_LOG(EXCEPTION) << "Memset failed.";
}
return false;
}
(void)client_noise->emplace(std::pair<std::string, std::vector<float>>(fl_id, noise));
} else {
// reconstruct individual noise
MS_LOG(INFO) << "start reconstruct individual noise.";
@@ -97,13 +101,23 @@ bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp,
return false;
}
std::vector<uint8_t> ind_iv = it->second[0];
if (Masking::GetMasking(&noise, cipher_init_->featuremap_, (const uint8_t *)secret, SECRET_MAX_LEN,
ind_iv.data(), ind_iv.size()) < 0)
retcode = false;
if (Masking::GetMasking(&noise, SizeToInt(cipher_init_->featuremap_), (const uint8_t *)secret, SECRET_MAX_LEN,
ind_iv.data(), SizeToInt(ind_iv.size())) < 0) {
MS_LOG(ERROR) << "Get Masking failed";
if (memset_s(secret, SECRET_MAX_LEN, 0, length) != 0) {
MS_LOG(EXCEPTION) << "Memset failed.";
}
BN_clear_free(prime);
return false;
}
for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) {
noise[index_noise] *= -1;
}
client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
(void)client_noise->emplace(std::pair<std::string, std::vector<float>>(fl_id, noise));
}
BN_clear_free(prime);
if (memset_s(secret, SECRET_MAX_LEN, 0, length) != 0) {
MS_LOG(EXCEPTION) << "Memset failed.";
}
} else {
MS_LOG(ERROR) << "reconstruct secret failed: the number of secret shares for fl_id: " << fl_id
@@ -157,6 +171,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
std::vector<Share *> shares_tmp;
if (!MallocShares(&shares_tmp, cipher_init_->secrets_minnums_)) {
MS_LOG(ERROR) << "Reconstruct malloc shares_tmp invalid.";
DeleteShares(&shares_tmp);
return false;
}

@@ -185,6 +200,24 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
return retcode;
}

bool CipherReconStruct::CheckInputs(const schema::SendReconstructSecret *reconstruct_secret_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb, const int cur_iterator,
const std::string &next_req_time) {
if (reconstruct_secret_req == nullptr) {
std::string reason = "Request is nullptr";
MS_LOG(ERROR) << reason;
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, cur_iterator, next_req_time);
return false;
}
if (cipher_init_ == nullptr) {
std::string reason = "cipher_init_ is nullptr";
MS_LOG(ERROR) << reason;
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError, reason, cur_iterator, next_req_time);
return false;
}
return true;
}

// reconstruct secrets
bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req,
@@ -192,13 +225,8 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
const std::vector<std::string> &client_list) {
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
clock_t start_time = clock();

if (reconstruct_secret_req == nullptr) {
std::string reason = "Request is nullptr";
MS_LOG(ERROR) << reason;
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, cur_iterator, next_req_time);
return false;
}
bool inputs_check = CheckInputs(reconstruct_secret_req, fbb, cur_iterator, next_req_time);
if (!inputs_check) return false;

int iterator = reconstruct_secret_req->iteration();
std::string fl_id = reconstruct_secret_req->fl_id()->str();
@@ -268,7 +296,8 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED,
"Success, but the server is not ready to reconstruct secret yet.", cur_iterator,
next_req_time);
MS_LOG(INFO) << "ReconstructSecrets " << fl_id << " Success, but count " << count_client_num << "is not enough.";
MS_LOG(INFO) << "Get reconstruct shares from " << fl_id << " Success, but count " << count_client_num
<< " is not enough.";
return true;
}
const fl::PBMetadata &clients_noises_pb_out =
@@ -339,7 +368,8 @@ void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::ser
bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs,
const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length) {
const string &fl_id, std::vector<float> *noise, const uint8_t *secret,
size_t length) {
for (auto p_key = clients_share_list.begin(); p_key != clients_share_list.end(); ++p_key) {
if (*p_key != fl_id) {
PrivateKey *privKey = KeyAgreement::FromPrivateBytes(secret, length);
@@ -357,6 +387,7 @@ bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_shar
auto iter = client_ivs.find(iv_fl_id);
if (iter == client_ivs.end()) {
MS_LOG(ERROR) << "cannot get ivs for client: " << iv_fl_id;
delete privKey;
return false;
}
if (iter->second.size() != IV_NUM) {
@@ -372,8 +403,11 @@ bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_shar
}
MS_LOG(INFO) << "private_key fl_id : " << fl_id << " public_key fl_id : " << *p_key;
uint8_t secret1[SECRET_MAX_LEN] = {0};
if (KeyAgreement::ComputeSharedKey(privKey, pubKey, SECRET_MAX_LEN, pw_salt.data(), pw_salt.size(), secret1) <
0) {
int ret = KeyAgreement::ComputeSharedKey(privKey, pubKey, SECRET_MAX_LEN, pw_salt.data(),
SizeToInt(pw_salt.size()), secret1);
delete privKey;
delete pubKey;
if (ret < 0) {
MS_LOG(ERROR) << "ComputeSharedKey failed\n";
return false;
}
@@ -426,7 +460,7 @@ bool CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::v
if (des->find(src_id) == des->end()) { // src_id is not in recombined shares list
std::vector<clientshare_str> value_list;
value_list.push_back(value);
des->insert(std::pair<std::string, std::vector<clientshare_str>>(src_id, value_list));
(void)des->emplace(std::pair<std::string, std::vector<clientshare_str>>(src_id, value_list));
} else {
des->at(src_id).push_back(value);
}
@@ -435,19 +469,18 @@ bool CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::v
return true;
}

bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int shares_size) {
bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, size_t shares_size) {
if (shares_tmp == nullptr) return false;
for (int i = 0; i < shares_size; ++i) {
Share *share_i = new Share();
for (size_t i = 0; i < shares_size; ++i) {
Share *share_i = new (std::nothrow) Share();
if (share_i == nullptr) {
MS_LOG(ERROR) << "shares_tmp " << i << " memory to cipher is invalid.";
DeleteShares(shares_tmp);
MS_LOG(ERROR) << "new Share failed.";
return false;
}
share_i->data = new uint8_t[SHARE_MAX_SIZE];
if (share_i->data == nullptr) {
MS_LOG(ERROR) << "shares_tmp's data " << i << " memory to cipher is invalid.";
DeleteShares(shares_tmp);
MS_LOG(ERROR) << "malloc memory failed.";
delete share_i;
return false;
}
share_i->index = 0;


+ 6
- 2
mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.h View File

@@ -37,6 +37,7 @@ class CipherReconStruct {
public:
// initialize: get cipher_init_
CipherReconStruct() { cipher_init_ = &CipherInit::GetInstance(); }
~CipherReconStruct() = default;

static CipherReconStruct &GetInstance() {
static CipherReconStruct instance;
@@ -64,9 +65,9 @@ class CipherReconStruct {
bool GetSuvNoise(const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs, const string &fl_id,
std::vector<float> *noise, uint8_t *secret, int length);
std::vector<float> *noise, const uint8_t *secret, size_t length);
// malloc shares.
bool MallocShares(std::vector<Share *> *shares_tmp, int shares_size);
bool MallocShares(std::vector<Share *> *shares_tmp, size_t shares_size);
// delete shares.
void DeleteShares(std::vector<Share *> *shares_tmp);
// convert shares from receiving clients to sending clients.
@@ -84,6 +85,9 @@ class CipherReconStruct {
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
const std::vector<string> &client_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &client_ivs);
bool CheckInputs(const schema::SendReconstructSecret *reconstruct_secret_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb, const int cur_iterator,
const std::string &next_req_time);
};
} // namespace armour
} // namespace mindspore


+ 34
- 25
mindspore/ccsrc/fl/armour/cipher/cipher_shares.cc View File

@@ -31,7 +31,13 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
cur_iterator);
return false;
}

if (cipher_init_ == nullptr) {
std::string reason = "cipher_init_ is nullptr";
MS_LOG(ERROR) << reason;
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SystemError, reason, next_req_time,
cur_iterator);
return false;
}
// step 1: get client list and share secrets from memory server.
clock_t start_time = clock();

@@ -89,25 +95,28 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
}

bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder,
const std::string &next_req_time) {
const std::shared_ptr<fl::server::FBBuilder> &fbb, const std::string &next_req_time) {
MS_LOG(INFO) << "CipherShares::GetSecrets START";
clock_t start_time = clock();
// step 0: check whether the parameters are legal.
if (get_secrets_req == nullptr) {
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, nullptr);
BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, 0, next_req_time, nullptr);
MS_LOG(ERROR) << "GetSecrets: get_secrets_req is nullptr.";
return false;
}

int iteration = get_secrets_req->iteration();
if (cipher_init_ == nullptr) {
MS_LOG(ERROR) << "cipher_init_ is nullptr";
BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, IntToSize(iteration), next_req_time, nullptr);
return false;
}
// step 1: get client list and client shares list from memory server.
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
&encrypted_shares_all);
int iteration = get_secrets_req->iteration();
size_t encrypted_shares_num = encrypted_shares_all.size();
if (cipher_init_->share_secrets_threshold > encrypted_shares_num) { // the client num is not enough, return false.
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
BuildGetSecretsRsp(fbb, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
MS_LOG(INFO) << "GetSecrets: the encrypted shares num is not enough: share_secrets_threshold: "
<< cipher_init_->share_secrets_threshold << "encrypted_shares_num: " << encrypted_shares_num;
return false;
@@ -116,7 +125,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
std::string fl_id = get_secrets_req->fl_id()->str();
// the client is not in share secrets client list.
if (encrypted_shares_all.find(fl_id) == encrypted_shares_all.end()) {
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, nullptr);
BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iteration, next_req_time, nullptr);
MS_LOG(ERROR) << "GetSecrets: client is not in share secrets client list.";
return false;
}
@@ -125,7 +134,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetSecretsClientList, fl_id);
if (!retcode_client) {
MS_LOG(ERROR) << "update get secrets clients failed";
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
BuildGetSecretsRsp(fbb, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
return false;
}

@@ -158,15 +167,14 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
std::vector<clientshare_str>::iterator ptr_start = encrypted_shares_add.begin();
std::vector<clientshare_str>::iterator ptr_end = ptr_start + size_shares;
for (std::vector<clientshare_str>::iterator ptr = ptr_start; ptr < ptr_end; ++ptr) {
auto one_fl_id = get_secrets_resp_builder->CreateString(ptr->fl_id);
auto two_share = get_secrets_resp_builder->CreateVector(ptr->share.data(), ptr->share.size());
auto one_fl_id = fbb->CreateString(ptr->fl_id);
auto two_share = fbb->CreateVector(ptr->share.data(), ptr->share.size());
auto third_index = ptr->index;
auto one_clientshare = schema::CreateClientShare(*get_secrets_resp_builder, one_fl_id, two_share, third_index);
auto one_clientshare = schema::CreateClientShare(*fbb, one_fl_id, two_share, third_index);
encrypted_shares.push_back(one_clientshare);
}

BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SUCCEED, iteration, next_req_time,
&encrypted_shares);
BuildGetSecretsRsp(fbb, schema::ResponseCode_SUCCEED, iteration, next_req_time, &encrypted_shares);
MS_LOG(INFO) << "CipherShares::GetSecrets Success";
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
@@ -175,20 +183,20 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
}

void CipherShares::BuildGetSecretsRsp(
const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) {
const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode, size_t iteration,
const std::string &next_req_time,
const std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) {
int rsp_retcode = retcode;
int rsp_iteration = iteration;
auto rsp_next_req_time = get_secrets_resp_builder->CreateString(next_req_time);
int rsp_iteration = SizeToInt(iteration);
auto rsp_next_req_time = fbb->CreateString(next_req_time);
if (encrypted_shares == nullptr) {
auto get_secrets_rsp =
schema::CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration, 0, rsp_next_req_time);
get_secrets_resp_builder->Finish(get_secrets_rsp);
auto get_secrets_rsp = schema::CreateReturnShareSecrets(*fbb, rsp_retcode, rsp_iteration, 0, rsp_next_req_time);
fbb->Finish(get_secrets_rsp);
} else {
auto encrypted_shares_rsp = get_secrets_resp_builder->CreateVector(*encrypted_shares);
auto get_secrets_rsp = CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration,
encrypted_shares_rsp, rsp_next_req_time);
get_secrets_resp_builder->Finish(get_secrets_rsp);
auto encrypted_shares_rsp = fbb->CreateVector(*encrypted_shares);
auto get_secrets_rsp =
CreateReturnShareSecrets(*fbb, rsp_retcode, rsp_iteration, encrypted_shares_rsp, rsp_next_req_time);
fbb->Finish(get_secrets_rsp);
}
return;
}
@@ -207,6 +215,7 @@ void CipherShares::BuildShareSecretsRsp(const std::shared_ptr<fl::server::FBBuil
void CipherShares::ClearShareSecrets() {
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxShareSecretsClientList);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsEncryptedShares);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxGetSecretsClientList);
}

} // namespace armour


+ 4
- 3
mindspore/ccsrc/fl/armour/cipher/cipher_shares.h View File

@@ -35,6 +35,7 @@ class CipherShares {
public:
// initialize: get cipher_init_
CipherShares() { cipher_init_ = &CipherInit::GetInstance(); }
~CipherShares() = default;

static CipherShares &GetInstance() {
static CipherShares instance;
@@ -55,9 +56,9 @@ class CipherShares {
const schema::ResponseCode retcode, const string &reason, const string &next_req_time,
const int iteration);
// build response code of get secrets.
void BuildGetSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder,
const schema::ResponseCode retcode, const int iteration, std::string next_req_time,
std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares);
void BuildGetSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const size_t iteration, const string &next_req_time,
const std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares);
// clear the shared memory.
void ClearShareSecrets();



+ 11
- 4
mindspore/ccsrc/fl/armour/cipher/cipher_unmask.cc View File

@@ -28,17 +28,24 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {

bool ret = cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
if (!ret || noise.size() != cipher_init_->featuremap_) {
MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
MS_LOG(WARNING) << "Client noises is not ready";
return false;
}

size_t data_size = fl::server::LocalMetaStore::GetInstance().value<size_t>(fl::server::kCtxFedAvgTotalDataSize);
if (data_size == 0) {
MS_LOG(ERROR) << "FedAvgTotalDataSize equals to 0";
return false;
}
int sum_size = 0;
for (auto iter = data.begin(); iter != data.end(); ++iter) {
int size_data = iter->second->size / sizeof(float);
if (iter->second == nullptr) {
MS_LOG(ERROR) << "AddressPtr is nullptr";
return false;
}
size_t size_data = iter->second->size / sizeof(float);
float *in_data = reinterpret_cast<float *>(iter->second->addr);
MS_LOG(INFO) << " weight name : " << iter->first;
for (int i = 0; i < size_data; ++i) {
for (size_t i = 0; i < size_data; ++i) {
in_data[i] = in_data[i] + noise[i + sum_size] / data_size;
}
sum_size += size_data;


+ 1
- 0
mindspore/ccsrc/fl/armour/cipher/cipher_unmask.h View File

@@ -33,6 +33,7 @@ class CipherUnmask {
public:
// initialize: get cipher_init_
CipherUnmask() { cipher_init_ = &CipherInit::GetInstance(); }
~CipherUnmask() = default;
// unmask the data by secret mask.
bool UnMask(const std::map<std::string, AddressPtr> &data);



+ 31
- 44
mindspore/ccsrc/fl/armour/secure_protocol/encrypt.cc View File

@@ -19,11 +19,11 @@
namespace mindspore {
namespace armour {
AESEncrypt::AESEncrypt(const uint8_t *key, int key_len, const uint8_t *ivec, int ivec_len, const AES_MODE mode) {
privKey = key;
privKeyLen = key_len;
iVec = ivec;
iVecLen = ivec_len;
aesMode = mode;
priv_key_ = key;
priv_key_len_ = key_len;
ivec_ = ivec;
ivec_len_ = ivec_len;
aes_mode_ = mode;
}

AESEncrypt::~AESEncrypt() {}
@@ -42,28 +42,23 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
#else
int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) {
int ret;
if (privKey == NULL || iVec == NULL) {
if (priv_key_ == nullptr || ivec_ == nullptr) {
MS_LOG(ERROR) << "private key or init vector is invalid.";
return -1;
}
if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
if (priv_key_len_ != KEY_LENGTH_16 && priv_key_len_ != KEY_LENGTH_32) {
MS_LOG(ERROR) << "key length is invalid.";
return -1;
}
if (iVecLen != AES_IV_SIZE) {
if (ivec_len_ != AES_IV_SIZE) {
MS_LOG(ERROR) << "initial vector size is invalid.";
return -1;
}
if (data == NULL || len <= 0 || encrypt_data == NULL) {
if (data == nullptr || len <= 0 || encrypt_data == nullptr || encrypt_len == nullptr) {
MS_LOG(ERROR) << "input data is invalid.";
return -1;
}
if (aesMode == AES_CBC || aesMode == AES_CTR) {
ret = evp_aes_encrypt(data, len, privKey, iVec, encrypt_data, encrypt_len);
} else {
MS_LOG(ERROR) << "Please use CBC mode or CTR mode, the other modes are not supported!";
ret = -1;
}
ret = evp_aes_encrypt(data, len, priv_key_, ivec_, encrypt_data, encrypt_len);
if (ret != 0) {
return -1;
}
@@ -72,26 +67,26 @@ int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned c

int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) {
int ret = 0;
if (privKey == NULL || iVec == NULL) {
if (priv_key_ == nullptr || ivec_ == nullptr) {
MS_LOG(ERROR) << "private key or init vector is invalid.";
return -1;
}
if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
if (priv_key_len_ != KEY_LENGTH_16 && priv_key_len_ != KEY_LENGTH_32) {
MS_LOG(ERROR) << "key length is invalid.";
return -1;
}
if (iVecLen != AES_IV_SIZE) {
if (ivec_len_ != AES_IV_SIZE) {
MS_LOG(ERROR) << "initial vector size is invalid.";
return -1;
}
if (data == NULL || encrypt_len <= 0 || encrypt_data == NULL) {
if (data == nullptr || encrypt_len <= 0 || encrypt_data == nullptr || len == nullptr) {
MS_LOG(ERROR) << "input data is invalid.";
return -1;
}
if (aesMode == AES_CBC || aesMode == AES_CTR) {
ret = evp_aes_decrypt(encrypt_data, encrypt_len, privKey, iVec, data, len);
if (aes_mode_ == AES_CBC || aes_mode_ == AES_CTR) {
ret = evp_aes_decrypt(encrypt_data, encrypt_len, priv_key_, ivec_, data, len);
} else {
MS_LOG(ERROR) << "Please use CBC mode or CTR mode, the other modes are not supported!";
MS_LOG(ERROR) << "This encryption mode is not supported!";
}
if (ret != 1) {
return -1;
@@ -99,17 +94,16 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
return 0;
}

int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *encrypt_data, int *encrypt_len) {
const int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *encrypt_data, int *encrypt_len) {
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
return -1;
}
int out_len;
int ret;
if (aesMode == AES_CBC) {
switch (privKeyLen) {
if (aes_mode_ == AES_CBC) {
switch (priv_key_len_) {
case KEY_LENGTH_16:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
break;
@@ -121,13 +115,12 @@ int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_
ret = -1;
}
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex CBC fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
EVP_CIPHER_CTX_set_padding(ctx, EVP_PADDING_PKCS7);
} else if (aesMode == AES_CTR) {
switch (privKeyLen) {
} else if (aes_mode_ == AES_CTR) {
switch (priv_key_len_) {
case KEY_LENGTH_16:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
break;
@@ -139,25 +132,22 @@ int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_
ret = -1;
}
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex CTR fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
} else {
MS_LOG(ERROR) << "Unsupported AES mode";
MS_LOG(ERROR) << "Unsupported encryption mode";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
ret = EVP_EncryptUpdate(ctx, encrypt_data, &out_len, data, len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptUpdate fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
*encrypt_len = out_len;
ret = EVP_EncryptFinal_ex(ctx, encrypt_data + *encrypt_len, &out_len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptFinal_ex fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
@@ -166,17 +156,16 @@ int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_
return 0;
}

int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *decrypt_data, int *decrypt_len) {
const int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key,
const uint8_t *ivec, uint8_t *decrypt_data, int *decrypt_len) {
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
return -1;
}
int out_len;
int ret;
if (aesMode == AES_CBC) {
switch (privKeyLen) {
if (aes_mode_ == AES_CBC) {
switch (priv_key_len_) {
case KEY_LENGTH_16:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
break;
@@ -191,8 +180,8 @@ int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, cons
EVP_CIPHER_CTX_free(ctx);
return -1;
}
} else if (aesMode == AES_CTR) {
switch (privKeyLen) {
} else if (aes_mode_ == AES_CTR) {
switch (priv_key_len_) {
case KEY_LENGTH_16:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
break;
@@ -208,21 +197,19 @@ int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, cons
return -1;
}
} else {
MS_LOG(ERROR) << "Unsupported AES mode";
MS_LOG(ERROR) << "Unsupported encryption mode";
EVP_CIPHER_CTX_free(ctx);
return -1;
}

ret = EVP_DecryptUpdate(ctx, decrypt_data, &out_len, encrypt_data, len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptUpdate fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
*decrypt_len = out_len;
ret = EVP_DecryptFinal_ex(ctx, decrypt_data + *decrypt_len, &out_len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptFinal_ex fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}


+ 9
- 9
mindspore/ccsrc/fl/armour/secure_protocol/encrypt.h View File

@@ -43,15 +43,15 @@ class AESEncrypt : SymmetricEncrypt {
int DecryptData(const uint8_t *encrypt_data, const int encrypt_len, uint8_t *data, int *len);

private:
const uint8_t *privKey;
int privKeyLen;
const uint8_t *iVec;
int iVecLen;
AES_MODE aesMode;
int evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *encrypt_data, int *encrypt_len);
int evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *decrypt_data, int *decrypt_len);
const uint8_t *priv_key_;
int priv_key_len_;
const uint8_t *ivec_;
int ivec_len_;
AES_MODE aes_mode_;
const int evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *encrypt_data, int *encrypt_len);
const int evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *decrypt_data, int *decrypt_len);
};

} // namespace armour


+ 26
- 17
mindspore/ccsrc/fl/armour/secure_protocol/key_agreement.cc View File

@@ -54,9 +54,9 @@ PrivateKey::PrivateKey(EVP_PKEY *evpKey) { evpPrivKey = evpKey; }

PrivateKey::~PrivateKey() { EVP_PKEY_free(evpPrivKey); }

int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) {
if (privKeyBytes == nullptr || len <= 0) {
MS_LOG(ERROR) << "input privKeyBytes invalid.";
int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) const {
if (privKeyBytes == nullptr || len == nullptr || evpPrivKey == nullptr) {
MS_LOG(ERROR) << "input data invalid.";
return -1;
}
if (!EVP_PKEY_get_raw_private_key(evpPrivKey, privKeyBytes, len)) {
@@ -65,8 +65,8 @@ int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) {
return 0;
}

int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) {
if (pubKeyBytes == nullptr || len <= 0) {
int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) const {
if (pubKeyBytes == nullptr || len == nullptr || evpPrivKey == nullptr) {
MS_LOG(ERROR) << "input pubKeyBytes invalid.";
return -1;
}
@@ -90,11 +90,10 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c
MS_LOG(ERROR) << "input salt in invalid.";
return -1;
}
EVP_PKEY_CTX *ctx;
size_t len = 0;
ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL);
if (!ctx) {
MS_LOG(ERROR) << "EVP_PKEY_CTX_new failed!";
EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL);
if (ctx == nullptr) {
MS_LOG(ERROR) << "new context failed!";
return -1;
}
if (EVP_PKEY_derive_init(ctx) <= 0) {
@@ -107,15 +106,17 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c
EVP_PKEY_CTX_free(ctx);
return -1;
}
unsigned char *secret;
if (EVP_PKEY_derive(ctx, NULL, &len) <= 0) {
MS_LOG(ERROR) << "get derive key size failed!";
EVP_PKEY_CTX_free(ctx);
return -1;
}

secret = (unsigned char *)OPENSSL_malloc(len);
if (!secret) {
if (len == 0) {
EVP_PKEY_CTX_free(ctx);
return -1;
}
uint8_t *secret = reinterpret_cast<uint8_t *>(OPENSSL_malloc(len));
if (secret == nullptr) {
MS_LOG(ERROR) << "malloc secret memory failed!";
EVP_PKEY_CTX_free(ctx);
return -1;
@@ -142,7 +143,7 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c
PrivateKey *KeyAgreement::GeneratePrivKey() {
EVP_PKEY *evpKey = NULL;
EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_X25519, NULL);
if (!pctx) {
if (pctx == nullptr) {
return NULL;
}
if (EVP_PKEY_keygen_init(pctx) <= 0) {
@@ -168,7 +169,7 @@ PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
return NULL;
}
pubKeyBytes = reinterpret_cast<uint8_t *>(OPENSSL_malloc(len));
if (!pubKeyBytes) {
if (pubKeyBytes == nullptr) {
MS_LOG(ERROR) << "malloc secret memory failed!";
return NULL;
}
@@ -190,7 +191,11 @@ PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
return pubKey;
}

PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) {
PrivateKey *KeyAgreement::FromPrivateBytes(const uint8_t *data, size_t len) {
if (data == nullptr) {
MS_LOG(ERROR) << "input data is null!";
return NULL;
}
EVP_PKEY *evp_Key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, NULL, data, len);
if (evp_Key == NULL) {
MS_LOG(ERROR) << "create evp_Key from raw bytes failed!";
@@ -200,7 +205,11 @@ PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) {
return privKey;
}

PublicKey *KeyAgreement::FromPublicBytes(unsigned char *data, int len) {
PublicKey *KeyAgreement::FromPublicBytes(const uint8_t *data, size_t len) {
if (data == nullptr) {
MS_LOG(ERROR) << "input data is null!";
return NULL;
}
EVP_PKEY *evp_pubKey = EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, data, len);
if (evp_pubKey == NULL) {
MS_LOG(ERROR) << "create evp_pubKey from raw bytes fail";


+ 4
- 4
mindspore/ccsrc/fl/armour/secure_protocol/key_agreement.h View File

@@ -48,8 +48,8 @@ class PrivateKey {
~PrivateKey();
int Exchange(PublicKey *peerPublicKey, int key_len, const unsigned char *salt, int salt_len,
unsigned char *exchangeKey);
int GetPrivateBytes(size_t *len, unsigned char *priKeyBytes);
int GetPublicBytes(size_t *len, unsigned char *pubKeyBytes);
int GetPrivateBytes(size_t *len, unsigned char *priKeyBytes) const;
int GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) const;
EVP_PKEY *evpPrivKey;
};
#endif
@@ -58,8 +58,8 @@ class KeyAgreement {
public:
static PrivateKey *GeneratePrivKey();
static PublicKey *GeneratePubKey(PrivateKey *privKey);
static PrivateKey *FromPrivateBytes(unsigned char *data, int len);
static PublicKey *FromPublicBytes(unsigned char *data, int len);
static PrivateKey *FromPrivateBytes(const unsigned char *data, size_t len);
static PublicKey *FromPublicBytes(const unsigned char *data, size_t len);
static int ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey, int key_len, const unsigned char *salt,
int salt_len, unsigned char *exchangeKey);
};


+ 10
- 3
mindspore/ccsrc/fl/armour/secure_protocol/secret_sharing.cc View File

@@ -18,9 +18,9 @@

namespace mindspore {
namespace armour {
void secure_zero(unsigned char *s, size_t n) {
volatile unsigned char *p = s;
if (p)
void secure_zero(uint8_t *s, size_t n) {
volatile uint8_t *p = s;
if (p != nullptr)
while (n--) *p++ = '\0';
}

@@ -215,17 +215,24 @@ int SecretSharing::Combine(size_t k, const std::vector<Share *> &shares, uint8_t
}
}

if (ret == -1) {
BN_clear_free(tmp);
break;
}
(void)BN_mod_inverse(tmp, denses[j], this->bn_prim_, ctx);
if (!field_mult(tmp, tmp, nums[j], ctx)) {
ret = -1;
BN_clear_free(tmp);
break;
}
if (!field_mult(tmp, tmp, y[j], ctx)) {
ret = -1;
BN_clear_free(tmp);
break;
}
if (!field_add(sum, sum, tmp, ctx)) {
ret = -1;
BN_clear_free(tmp);
break;
}
BN_clear_free(tmp);


+ 0
- 1
mindspore/ccsrc/fl/server/common.h View File

@@ -206,7 +206,6 @@ constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelThld = "update_model_threshold";
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
constexpr auto kCtxClientsKeys = "clients_keys";
constexpr auto kCtxClientNoises = "clients_noises";
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";


+ 5
- 4
mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc View File

@@ -54,7 +54,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
for (size_t i = 0; i < IntToSize(client_list_pb.fl_id_size()); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
}
if (static_cast<uint64_t>(client_list.size()) < update_model_client_needed) {
@@ -87,7 +87,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
return true;
return false;
}
MS_LOG(INFO) << "send clients_list succeed!";
MS_LOG(INFO) << "UpdateModel client list: ";
@@ -100,7 +100,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
return true;
}

bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
@@ -148,6 +148,7 @@ bool ClientListKernel::Reset() {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Get Client list kernel reset!";
DistributedCountService::GetInstance().ResetCounter(name_);
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxGetUpdateModelClientList);
StopTimer();
return true;
}
@@ -170,7 +171,7 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_clients(clients_fb);
rsp_builder.add_iteration(iteration);
rsp_builder.add_iteration(SizeToInt(iteration));
rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys);


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.h View File

@@ -34,7 +34,7 @@ class ClientListKernel : public RoundKernel {
ClientListKernel() = default;
~ClientListKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
void BuildClientListRsp(std::shared_ptr<server::FBBuilder> client_list_resp_builder,


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc View File

@@ -65,7 +65,7 @@ bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> &
return true;
}

bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching ExchangeKey kernel.";
bool response = false;


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.h View File

@@ -35,7 +35,7 @@ class ExchangeKeysKernel : public RoundKernel {
ExchangeKeysKernel() = default;
~ExchangeKeysKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;



+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc View File

@@ -51,7 +51,7 @@ bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const
return true;
}

bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching GetKeys kernel.";
bool response = false;


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.h View File

@@ -35,7 +35,7 @@ class GetKeysKernel : public RoundKernel {
GetKeysKernel() = default;
~GetKeysKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;



+ 3
- 1
mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc View File

@@ -18,6 +18,8 @@
#include <vector>
#include <memory>
#include <string>
#include <map>
#include <utility>
#include "fl/armour/cipher/cipher_shares.h"

namespace mindspore {
@@ -52,7 +54,7 @@ bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
return true;
}

bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.h View File

@@ -34,7 +34,7 @@ class GetSecretsKernel : public RoundKernel {
GetSecretsKernel() = default;
~GetSecretsKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;



+ 2
- 2
mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc View File

@@ -50,7 +50,7 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
{first_cnt_handler, last_cnt_handler});
}

bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
@@ -80,7 +80,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();

for (int i = 0; i < update_model_clients_pb.fl_id_size(); ++i) {
for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) {
update_model_clients.push_back(update_model_clients_pb.fl_id(i));
}



+ 2
- 2
mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.h View File

@@ -36,7 +36,7 @@ class ReconstructSecretsKernel : public RoundKernel {
~ReconstructSecretsKernel() override = default;

void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
@@ -44,7 +44,7 @@ class ReconstructSecretsKernel : public RoundKernel {
private:
std::string name_unmask_;
Executor *executor_;
size_t iteration_time_window_;
size_t iteration_time_window_{0};
armour::CipherReconStruct cipher_reconstruct_;
};
} // namespace kernel


+ 2
- 3
mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc View File

@@ -38,8 +38,7 @@ void ShareSecretsKernel::InitKernel(size_t) {

bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestShareSecrets *share_secrets_req,
const int iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(share_secrets_req, false);
const size_t iter_num) {
if (!DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str())) {
std::string reason = "Counting for share secret kernel request failed. Please retry later.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason,
@@ -50,7 +49,7 @@ bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> &
return true;
}

bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();


+ 2
- 2
mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.h View File

@@ -35,7 +35,7 @@ class ShareSecretsKernel : public RoundKernel {
ShareSecretsKernel() = default;
~ShareSecretsKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;

@@ -44,7 +44,7 @@ class ShareSecretsKernel : public RoundKernel {
size_t iteration_time_window_;
armour::CipherShares *cipher_share_;
bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req,
const int iter_num);
const size_t iter_num);
};
} // namespace kernel
} // namespace server


+ 0
- 1
mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc View File

@@ -118,7 +118,6 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}

GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}


+ 18
- 21
mindspore/ccsrc/fl/server/server.cc View File

@@ -245,38 +245,18 @@ void Server::InitIteration() {
void Server::InitCipher() {
#ifdef ENABLE_ARMOUR
cipher_init_ = &armour::CipherInit::GetInstance();

int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
const int cipher_g = 1;
unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
float dp_eps = ps::PSContext::instance()->dp_eps();
float dp_delta = ps::PSContext::instance()->dp_delta();
float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();

BIGNUM *prim = BN_new();
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "new bn failed";
}

mindspore::armour::GetPrime(prim);

MS_LOG(INFO) << "prime" << BN_bn2hex(prim);
(void)BN_bn2bin(prim, reinterpret_cast<uint8_t *>(cipher_prime));
if (prim != nullptr) {
BN_clear_free(prim);
}

mindspore::armour::CipherPublicPara param;
param.g = cipher_g;
param.t = cipher_t;
int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
ret = memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
@@ -285,6 +265,23 @@ void Server::InitCipher() {
param.dp_eps = dp_eps;
param.dp_norm_clip = dp_norm_clip;
param.encrypt_type = encrypt_type;

BIGNUM *prim = BN_new();
if (prim == NULL) {
MS_LOG(EXCEPTION) << "new bn failed.";
ret = -1;
} else {
ret = mindspore::armour::GetPrime(prim);
}
if (ret == 0) {
(void)BN_bn2bin(prim, reinterpret_cast<uint8_t *>(param.prime));
} else {
MS_LOG(EXCEPTION) << "Get prime failed.";
}
if (prim != NULL) {
BN_clear_free(prim);
}

cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
cipher_reconstruct_secrets_up_cnt_);


Loading…
Cancel
Save