| @@ -51,6 +51,8 @@ if(NOT ENABLE_CPU OR WIN32) | |||
| list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_reconstruct.cc") | |||
| list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_shares.cc") | |||
| list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc") | |||
| list(REMOVE_ITEM _FL_SRC_FILES "compression/decode_executor.cc") | |||
| list(REMOVE_ITEM _FL_SRC_FILES "compression/encode_executor.cc") | |||
| endif() | |||
| if(CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| @@ -0,0 +1,150 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "fl/compression/decode_executor.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| namespace compression { | |||
| std::vector<int> DecodeExecutor::ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num) { | |||
| static int multiplier = 2147483647; | |||
| static double increment = 4294967294.0; | |||
| static int modulo = 48271; | |||
| size_t retain_num = size_t(static_cast<float>(param_num) * upload_sparse_rate); | |||
| if (retain_num == 0) { | |||
| MS_LOG(WARNING) << "The retain_num is 0, and upload_sparse_rate is too small."; | |||
| } | |||
| std::vector<int> mask_array(param_num, 0); | |||
| for (size_t i = 0; i < retain_num; ++i) { | |||
| mask_array[i] = 1; | |||
| } | |||
| seed = ((seed + multiplier) * modulo) % multiplier; | |||
| for (size_t i = 0; i < param_num; ++i) { | |||
| // generate random number in (0, 1) | |||
| double rand = static_cast<double>(seed) / increment + 0.5; | |||
| // update seed | |||
| seed = (seed * modulo) % multiplier; | |||
| size_t j = size_t(rand * static_cast<double>(param_num - i)) + i; | |||
| int temp = mask_array[i]; | |||
| mask_array[i] = mask_array[j]; | |||
| mask_array[j] = temp; | |||
| } | |||
| return mask_array; | |||
| } | |||
// Decodes one client upload that was compressed with min-max quantization,
// random sparsification and parameter differencing (in that encode order;
// decoding runs the inverse steps in reverse).
// @param weight_map            [out] decoded full-size weights, keyed by name
// @param compress_feature_maps quantized int8 chunks with per-chunk min/max
// @param num_bits              quantization bit width (8 for DIFF_SPARSE_QUANT)
// @param upload_sparse_rate    fraction of parameters the client uploaded
// @param seed                  shared PRNG seed used to rebuild the sparse mask
// @param name_vec              weight names, fixing the decode order and shapes
// @param data_size             client weighting factor applied to the old model
// @return false when shapes / upload sizes do not match, true on success
bool DecodeExecutor::DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
                                       const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
                                       float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
                                       size_t data_size) {
  std::vector<std::vector<float>> decompress_feature_maps;
  // origin parameters
  std::vector<size_t> shape_vec;
  size_t param_num = 0;
  // NOTE(review): rbegin() on an empty map is UB — presumably ModelStore
  // always holds at least the initial model; confirm.
  const auto &iter_to_model = mindspore::fl::server::ModelStore::GetInstance().iteration_to_model();
  size_t latest_iter_num = iter_to_model.rbegin()->first;
  std::map<std::string, AddressPtr> feature_maps =
    mindspore::fl::server::ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
  // get shape vector and number of upload parameters
  for (const auto &name : name_vec) {
    // NOTE(review): operator[] default-inserts a null AddressPtr when `name`
    // is not part of the stored model, and the immediate dereference would
    // crash — confirm name_vec is validated upstream.
    size_t shape = feature_maps[name]->size / sizeof(float);
    shape_vec.emplace_back(shape);
    param_num += shape;
  }
  MS_LOG(DEBUG) << "Compression get last weights success!";
  // quant decode
  // temp1 = 2^num_bits - 1 (number of quant steps), temp2 = 2^(num_bits - 1)
  // (offset that recenters the signed int8 values).
  auto temp1 = static_cast<float>(1 << num_bits) - 1.0f;
  auto temp2 = static_cast<float>(1 << (num_bits - 1));
  std::vector<float> de_min_max_feature_map;
  // NOTE(review): iterating by value copies each CompressFeatureMap including
  // its data vector; a const reference would avoid the copies.
  for (auto compress_feature_map : compress_feature_maps) {
    float min_val = compress_feature_map.min_val;
    float max_val = compress_feature_map.max_val;
    // The 1e-10f epsilon keeps the scale non-zero for constant tensors.
    float scale_val = static_cast<float>(max_val - min_val) / temp1 + 1e-10f;
    size_t size = compress_feature_map.compress_data.size();
    for (size_t i = 0; i < size; ++i) {
      // Inverse of the encoder's mapping: value = (q + 2^(num_bits-1)) * scale + min.
      de_min_max_feature_map.emplace_back(
        (static_cast<float>(compress_feature_map.compress_data[i]) + temp2) * scale_val + min_val);
    }
  }
  MS_LOG(DEBUG) << "Compression quant decode success!";
  // sparse decode
  // Rebuild the same pseudo-random mask the client used, then scatter the
  // uploaded values back into full-size tensors (zeros where masked out).
  std::vector<int> mask_array = ConstructMaskArray(seed, upload_sparse_rate, param_num);
  size_t index = 0;
  size_t de_min_max_feature_map_index = 0;
  for (const auto &shape : shape_vec) {
    std::vector<float> feature_map(shape);
    for (size_t i = 0; i < shape; ++i) {
      if (index >= mask_array.size()) {
        MS_LOG(WARNING) << "The mask_array and parameter shape is not matched.";
        return false;
      }
      if (mask_array[index] == 1) {
        if (de_min_max_feature_map_index >= de_min_max_feature_map.size()) {
          MS_LOG(WARNING) << "The number of upload parameters is too small.";
          return false;
        }
        feature_map[i] = de_min_max_feature_map[de_min_max_feature_map_index];
        de_min_max_feature_map_index += 1;
      } else {
        feature_map[i] = 0.0f;
      }
      index += 1;
    }
    decompress_feature_maps.emplace_back(feature_map);
  }
  MS_LOG(DEBUG) << "Compression sparse decode success!";
  // difference decode
  // The upload is a difference against the last stored model; add back
  // data_size * old_weight to recover the aggregatable value.
  for (size_t i = 0; i < decompress_feature_maps.size(); ++i) {
    size_t feature_size = decompress_feature_maps[i].size();
    std::string name = name_vec[i];
    float *weight_data = reinterpret_cast<float *>(feature_maps[name]->addr);
    auto &weight_item = (*weight_map)[name];
    weight_item.resize(feature_size);
    for (size_t j = 0; j < feature_size; ++j) {
      weight_item[j] = decompress_feature_maps[i][j] + data_size * weight_data[j];
    }
  }
  MS_LOG(DEBUG) << "Compression difference decode success!";
  return true;
}
| bool DecodeExecutor::Decode(std::map<std::string, std::vector<float>> *weight_map, | |||
| const std::vector<CompressFeatureMap> &compress_feature_maps, | |||
| schema::CompressType upload_compress_type, float upload_sparse_rate, int seed, | |||
| const std::vector<std::string> &name_vec, size_t data_size) { | |||
| if (upload_compress_type == schema::CompressType_DIFF_SPARSE_QUANT) { | |||
| return DeQuantSparseDiff(weight_map, compress_feature_maps, 8, upload_sparse_rate, seed, name_vec, data_size); | |||
| } | |||
| return false; | |||
| } | |||
| schema::CompressType DecodeExecutor::GetCompressType(schema::CompressType upload_compress_type) { | |||
| if (upload_compress_type == schema::CompressType_DIFF_SPARSE_QUANT) { | |||
| MS_LOG(DEBUG) << "This upload compress type is DIFF_SPARSE_QUANT."; | |||
| return schema::CompressType_DIFF_SPARSE_QUANT; | |||
| } | |||
| MS_LOG(DEBUG) << "This upload compress type is NO_COMPRESS."; | |||
| return schema::CompressType_NO_COMPRESS; | |||
| } | |||
| } // namespace compression | |||
| } // namespace fl | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_ | |||
| #define MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <cstdio> | |||
| #include <cstdlib> | |||
| #include <cstring> | |||
| #include <functional> | |||
| #include <algorithm> | |||
| #include <regex> | |||
| #include <map> | |||
| #include <utility> | |||
| #include "proto/comm.pb.h" | |||
| #include "schema/fl_job_generated.h" | |||
| #include "schema/cipher_generated.h" | |||
| #include "fl/server/model_store.h" | |||
| #include "fl/server/common.h" | |||
| #include "ps/ps_context.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| namespace compression { | |||
// One compressed (quantized) weight chunk uploaded by a client.
struct CompressFeatureMap {
  // Full name of the weight this chunk belongs to.
  std::string weight_fullname;
  // Quantized 8-bit weight data.
  std::vector<int8_t> compress_data;
  // De-quantization range recorded at encode time.
  // Default-initialized so a default-constructed instance holds no garbage.
  float min_val = 0.0f;
  float max_val = 0.0f;
};
// Server-side decoder for compressed weights uploaded by federated-learning
// clients. Singleton; access via GetInstance().
class DecodeExecutor {
 public:
  static DecodeExecutor &GetInstance() {
    static DecodeExecutor instance;
    return instance;
  }
  // construct mask array for random sparse
  // Rebuilds the deterministic 0/1 mask from `seed`; must reproduce the
  // client-side construction exactly.
  std::vector<int> ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num);
  // decode min_max quantization and random sparse and parameter difference
  // Returns false when the upload does not match the stored model shapes.
  bool DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
                         const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
                         float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
                         size_t data_size);
  // decode
  // Dispatches on `upload_compress_type`; only DIFF_SPARSE_QUANT is supported.
  bool Decode(std::map<std::string, std::vector<float>> *weight_map,
              const std::vector<CompressFeatureMap> &compress_feature_maps, schema::CompressType upload_compress_type,
              float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec, size_t data_size);
  // Maps the requested upload compress type to a supported one
  // (NO_COMPRESS fallback).
  schema::CompressType GetCompressType(schema::CompressType upload_compress_type);
};
| } // namespace compression | |||
| } // namespace fl | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_ | |||
| @@ -0,0 +1,102 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "fl/compression/encode_executor.h" | |||
| #include <arpa/inet.h> | |||
| #include <cstdio> | |||
| #include <cstdlib> | |||
| #include <cstring> | |||
| #include <functional> | |||
| #include <algorithm> | |||
| #include <regex> | |||
| #include <map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "fl/server/common.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| namespace compression { | |||
| bool CompressExecutor::EnableCompressWeight(const schema::CompressType compressType) { | |||
| return kCompressTypeMap.count(compressType) > 0; | |||
| } | |||
| bool CompressExecutor::construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights, | |||
| std::map<std::string, std::vector<float>> feature_maps, | |||
| const schema::CompressType compressType) { | |||
| if (compressType == schema::CompressType_QUANT) { | |||
| return quant_min_max(compressWeights, feature_maps, kCompressTypeMap.at(compressType)); | |||
| } | |||
| return false; | |||
| } | |||
| bool CompressExecutor::quant_min_max(std::map<std::string, CompressWeight> *compressWeights, | |||
| std::map<std::string, std::vector<float>> feature_maps, size_t num_bits) { | |||
| auto temp1 = static_cast<float>(1 << num_bits) - 1.0f; | |||
| auto temp2 = static_cast<float>(1 << (num_bits - 1)); | |||
| for (const auto &feature_map : feature_maps) { | |||
| std::string weight_name = feature_map.first; | |||
| float min_value = 1e10f; | |||
| float max_value = -min_value; | |||
| for (const auto &feature : feature_map.second) { | |||
| if (feature > max_value) { | |||
| max_value = feature; | |||
| } | |||
| if (feature < min_value) { | |||
| min_value = feature; | |||
| } | |||
| } | |||
| float scale_value = (max_value - min_value) / temp1 + 1e-10f; | |||
| size_t size = feature_map.second.size(); | |||
| if (size == 0) { | |||
| MS_LOG(WARNING) << "The size of parameters is zero."; | |||
| return false; | |||
| } | |||
| CompressWeight compressWeight; | |||
| for (size_t i = 0; i < size; ++i) { | |||
| auto round_data = round((feature_map.second[i] - min_value) / scale_value - temp2); | |||
| // bit pack can be implemented here in the future | |||
| auto int8_data = int8_t(round_data); | |||
| compressWeight.compress_data.emplace_back(int8_data); | |||
| } | |||
| compressWeight.min_val = min_value; | |||
| compressWeight.max_val = max_value; | |||
| compressWeight.compress_data_len = size; | |||
| (*compressWeights)[weight_name] = compressWeight; | |||
| } | |||
| return true; | |||
| } | |||
| schema::CompressType CompressExecutor::GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types) { | |||
| schema::CompressType compressType = schema::CompressType_NO_COMPRESS; | |||
| if (download_compress_types == nullptr) { | |||
| MS_LOG(DEBUG) << "The client does not support current download compress type."; | |||
| } else { | |||
| for (size_t i = 0; i < download_compress_types->size(); ++i) { | |||
| auto download_compress_type = download_compress_types->Get(i); | |||
| if (download_compress_type == schema::CompressType_QUANT) { | |||
| compressType = schema::CompressType_QUANT; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| return compressType; | |||
| } | |||
| } // namespace compression | |||
| } // namespace fl | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_ | |||
| #define MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "proto/comm.pb.h" | |||
| #include "schema/fl_job_generated.h" | |||
| #include "schema/cipher_generated.h" | |||
| #include "fl/armour/secure_protocol/key_agreement.h" | |||
| #include "ps/ps_context.h" | |||
| #include "ps/core/worker_node.h" | |||
| #include "ps/core/cluster_metadata.h" | |||
| #include "ps/core/communicator/tcp_communicator.h" | |||
| #include "fl/server/common.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| namespace compression { | |||
| // compress type map: schema::CompressType -> num bits | |||
| const std::map<schema::CompressType, size_t> kCompressTypeMap = {{schema::CompressType_QUANT, 8}}; | |||
// One quantized weight tensor prepared for client download.
struct CompressWeight {
  // Quantized 8-bit weight data.
  std::vector<int8_t> compress_data;
  // Number of quantized elements (== compress_data.size() after encoding).
  // All scalar members are default-initialized so a default-constructed
  // instance holds no garbage.
  size_t compress_data_len = 0;
  // De-quantization range recorded at encode time.
  float min_val = 0.0f;
  float max_val = 0.0f;
};
// Server-side encoder that builds compressed (quantized) weights for download
// by federated-learning clients. Singleton; access via GetInstance().
class CompressExecutor {
 public:
  static CompressExecutor &GetInstance() {
    static CompressExecutor instance;
    return instance;
  }
  // True when `compressType` is supported (has an entry in kCompressTypeMap).
  bool EnableCompressWeight(const schema::CompressType compressType);
  // Builds compressed weights for every entry of `feature_maps` into
  // `compressWeights`. Returns false for unsupported types or empty tensors.
  // NOTE(review): feature_maps is passed by value and copied; consider const&.
  bool construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
                                 std::map<std::string, std::vector<float>> feature_maps,
                                 const schema::CompressType compressType);
  // Min-max quantization of each float tensor to num_bits-bit signed integers.
  bool quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
                     std::map<std::string, std::vector<float>> feature_maps, size_t num_bits);
  // Picks the first supported download compress type advertised by the client;
  // NO_COMPRESS when the list is null or contains no supported entry.
  schema::CompressType GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types);
};
| } // namespace compression | |||
| } // namespace fl | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_ | |||
| @@ -149,6 +149,11 @@ constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum"; | |||
| constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum"; | |||
| constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum"; | |||
| constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum"; | |||
// String identifiers used by the model-weight compression feature:
// suffixes for the per-tensor quantization range entries, and compress-type
// names (used e.g. as model-response cache keys).
constexpr auto kMinVal = "min_val";
constexpr auto kMaxVal = "max_val";
constexpr auto kQuant = "QUANT";
constexpr auto kDiffSparseQuant = "DIFF_SPARSE_QUANT";
constexpr auto kNoCompress = "NO_COMPRESS";
| // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is | |||
| // launched. | |||
| @@ -588,6 +588,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { | |||
| if (LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map)) { | |||
| ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); | |||
| ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model); | |||
| iteration_result_ = IterationResult::kSuccess; | |||
| MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished."; | |||
| } else { | |||
| @@ -599,6 +600,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { | |||
| size_t latest_iter_num = iter_to_model.rbegin()->first; | |||
| const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num); | |||
| ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); | |||
| ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model); | |||
| iteration_result_ = IterationResult::kFail; | |||
| MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason; | |||
| } | |||
| @@ -92,7 +92,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, | |||
| return; | |||
| } | |||
| auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp); | |||
| std::map<std::string, AddressPtr> feature_maps; | |||
| std::map<std::string, AddressPtr> feature_maps = {}; | |||
| size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| size_t get_model_iter = IntToSize(get_model_req->iteration()); | |||
| const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model(); | |||
| @@ -110,6 +110,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return; | |||
| } | |||
| IncreaseAcceptClientNum(); | |||
| auto real_get_model_iter = get_model_iter; | |||
| if (iter_to_model.count(get_model_iter) == 0) { | |||
| @@ -118,12 +119,37 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, | |||
| << " is invalid. Current iteration is " << std::to_string(current_iter); | |||
| real_get_model_iter = latest_iter_num; | |||
| } | |||
| auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter); | |||
| auto download_compress_types = get_model_req->download_compress_types(); | |||
| schema::CompressType compressType = | |||
| mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types); | |||
| std::string compress_type; | |||
| if (compressType == schema::CompressType_QUANT) { | |||
| compress_type = kQuant; | |||
| } else { | |||
| compress_type = kNoCompress; | |||
| } | |||
| auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter, compress_type); | |||
| if (cache == nullptr) { | |||
| feature_maps = ModelStore::GetInstance().GetModelByIterNum(real_get_model_iter); | |||
| // Only download compress weights if client support. | |||
| std::map<std::string, AddressPtr> compress_feature_maps = {}; | |||
| if (compressType == schema::CompressType_NO_COMPRESS) { | |||
| feature_maps = ModelStore::GetInstance().GetModelByIterNum(real_get_model_iter); | |||
| } else { | |||
| auto compressExecutor = mindspore::fl::compression::CompressExecutor::GetInstance(); | |||
| if (compressExecutor.EnableCompressWeight(compressType)) { | |||
| const auto &iter_to_compress_model = ModelStore::GetInstance().iteration_to_compress_model(); | |||
| if (iter_to_compress_model.count(get_model_iter) == 0) { | |||
| MS_LOG(DEBUG) << "The iteration of GetCompressModel request " << std::to_string(get_model_iter) | |||
| << " is invalid. Current iteration is " << std::to_string(current_iter); | |||
| compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(latest_iter_num, compressType); | |||
| } else { | |||
| compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(get_model_iter, compressType); | |||
| } | |||
| } | |||
| } | |||
| BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter), | |||
| current_iter, feature_maps, std::to_string(next_req_time)); | |||
| cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter, | |||
| current_iter, feature_maps, std::to_string(next_req_time), compressType, compress_feature_maps); | |||
| cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter, compress_type, | |||
| fbb->GetBufferPointer(), fbb->GetSize()); | |||
| if (cache == nullptr) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| @@ -131,7 +157,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, | |||
| } | |||
| } | |||
| SendResponseMsgInference(message, cache->data(), cache->size(), ModelStore::GetInstance().RelModelResponseCache); | |||
| MS_LOG(DEBUG) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() | |||
| MS_LOG(DEBUG) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() | |||
| << ", next request time is " << next_req_time << ", current iteration is " << current_iter; | |||
| return; | |||
| } | |||
| @@ -139,7 +165,8 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, | |||
| void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const size_t iter, | |||
| const std::map<std::string, AddressPtr> &feature_maps, | |||
| const std::string ×tamp) { | |||
| const std::string ×tamp, const schema::CompressType &compressType, | |||
| const std::map<std::string, AddressPtr> &compress_feature_maps) { | |||
| if (fbb == nullptr) { | |||
| MS_LOG(ERROR) << "Input fbb is nullptr."; | |||
| return; | |||
| @@ -156,12 +183,40 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con | |||
| } | |||
| auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); | |||
| // construct compress feature maps with fbs | |||
| std::vector<flatbuffers::Offset<schema::CompressFeatureMap>> fbs_compress_feature_maps; | |||
| for (const auto &compress_feature_map : compress_feature_maps) { | |||
| if (compress_feature_map.first.find(kMinVal) != string::npos || | |||
| compress_feature_map.first.find(kMaxVal) != string::npos) { | |||
| continue; | |||
| } | |||
| auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first); | |||
| auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast<int8_t *>(compress_feature_map.second->addr), | |||
| compress_feature_map.second->size / sizeof(int8_t)); | |||
| const std::string min_val_name = compress_feature_map.first + "." + kMinVal; | |||
| const std::string max_val_name = compress_feature_map.first + "." + kMaxVal; | |||
| const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name); | |||
| const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name); | |||
| float *fbs_min_val_ptr = reinterpret_cast<float *>(min_val_ptr->addr); | |||
| float *fbs_max_val_ptr = reinterpret_cast<float *>(max_val_ptr->addr); | |||
| auto fbs_compress_feature_map = schema::CreateCompressFeatureMap( | |||
| *(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr); | |||
| fbs_compress_feature_maps.push_back(fbs_compress_feature_map); | |||
| } | |||
| auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps); | |||
| schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get())); | |||
| rsp_get_model_builder.add_retcode(static_cast<int>(retcode)); | |||
| rsp_get_model_builder.add_reason(fbs_reason); | |||
| rsp_get_model_builder.add_iteration(static_cast<int>(iter)); | |||
| rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector); | |||
| rsp_get_model_builder.add_timestamp(fbs_timestamp); | |||
| rsp_get_model_builder.add_download_compress_type(compressType); | |||
| rsp_get_model_builder.add_compress_feature_map(fbs_compress_feature_maps_vector); | |||
| auto rsp_get_model = rsp_get_model_builder.Finish(); | |||
| fbb->Finish(rsp_get_model); | |||
| return; | |||
| @@ -25,6 +25,7 @@ | |||
| #include "fl/server/executor.h" | |||
| #include "fl/server/kernel/round/round_kernel.h" | |||
| #include "fl/server/kernel/round/round_kernel_factory.h" | |||
| #include "fl/compression/encode_executor.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| @@ -44,7 +45,9 @@ class GetModelKernel : public RoundKernel { | |||
| void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<ps::core::MessageHandler> &message); | |||
| void BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const size_t iter, | |||
| const std::map<std::string, AddressPtr> &feature_maps, const std::string ×tamp); | |||
| const std::map<std::string, AddressPtr> &feature_maps, const std::string ×tamp, | |||
| const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS, | |||
| const std::map<std::string, AddressPtr> &compress_feature_maps = {}); | |||
| // The executor is for getting model for getModel request. | |||
| Executor *executor_; | |||
| @@ -126,10 +126,19 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, | |||
| IncreaseAcceptClientNum(); | |||
| auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| auto last_iteration = curr_iter_num - 1; | |||
| auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration); | |||
| auto download_compress_types = start_fl_job_req->download_compress_types(); | |||
| schema::CompressType compressType = | |||
| mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types); | |||
| std::string compress_type; | |||
| if (compressType == schema::CompressType_QUANT) { | |||
| compress_type = kQuant; | |||
| } else { | |||
| compress_type = kNoCompress; | |||
| } | |||
| auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration, compress_type); | |||
| if (cache == nullptr) { | |||
| StartFLJob(fbb); | |||
| cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration, | |||
| StartFLJob(fbb, device_meta, start_fl_job_req); | |||
| cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration, compress_type, | |||
| fbb->GetBufferPointer(), fbb->GetSize()); | |||
| if (cache == nullptr) { | |||
| SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| @@ -303,22 +312,40 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> | |||
| return ResultCode::kSuccess; | |||
| } | |||
| void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb) { | |||
| void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &, | |||
| const schema::RequestFLJob *start_fl_job_req) { | |||
| size_t last_iteration = LocalMetaStore::GetInstance().curr_iter_num() - 1; | |||
| auto feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration); | |||
| if (feature_maps.empty()) { | |||
| MS_LOG(WARNING) << "The feature map for startFLJob is empty."; | |||
| std::map<std::string, AddressPtr> feature_maps = {}; | |||
| std::map<std::string, AddressPtr> compress_feature_maps = {}; | |||
| // Only download compress weights if client support. | |||
| auto download_compress_types = start_fl_job_req->download_compress_types(); | |||
| schema::CompressType compressType = | |||
| mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types); | |||
| if (compressType == schema::CompressType_NO_COMPRESS) { | |||
| feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration); | |||
| if (feature_maps.empty()) { | |||
| MS_LOG(WARNING) << "The feature map for startFLJob is empty."; | |||
| } | |||
| } else { | |||
| if (mindspore::fl::compression::CompressExecutor::GetInstance().EnableCompressWeight(compressType)) { | |||
| compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(last_iteration, compressType); | |||
| } | |||
| } | |||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), | |||
| feature_maps); | |||
| feature_maps, compressType, compress_feature_maps); | |||
| return; | |||
| } | |||
| void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const bool is_selected, | |||
| const std::string &next_req_time, | |||
| std::map<std::string, AddressPtr> feature_maps) { | |||
| const std::map<std::string, AddressPtr> &feature_maps, | |||
| const schema::CompressType &compressType, | |||
| const std::map<std::string, AddressPtr> &compress_feature_maps) { | |||
| if (fbb == nullptr) { | |||
| MS_LOG(WARNING) << "Input fbb is nullptr."; | |||
| return; | |||
| @@ -350,6 +377,12 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, | |||
| auto cipher_public_params = | |||
| schema::CreateCipherPublicParams(*fbb.get(), encrypt_type, pw_params, dp_params, ds_params); | |||
| #endif | |||
| schema::CompressType upload_compress_type; | |||
| if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) { | |||
| upload_compress_type = schema::CompressType_DIFF_SPARSE_QUANT; | |||
| } else { | |||
| upload_compress_type = schema::CompressType_NO_COMPRESS; | |||
| } | |||
| schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); | |||
| fl_plan_builder.add_fl_name(fbs_fl_name); | |||
| @@ -375,6 +408,33 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, | |||
| } | |||
| auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); | |||
| // construct compress feature maps with fbs | |||
| std::vector<flatbuffers::Offset<schema::CompressFeatureMap>> fbs_compress_feature_maps; | |||
| for (const auto &compress_feature_map : compress_feature_maps) { | |||
| if (compressType == schema::CompressType_QUANT) { | |||
| if (compress_feature_map.first.find(kMinVal) != string::npos || | |||
| compress_feature_map.first.find(kMaxVal) != string::npos) { | |||
| continue; | |||
| } | |||
| auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first); | |||
| auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast<int8_t *>(compress_feature_map.second->addr), | |||
| compress_feature_map.second->size / sizeof(int8_t)); | |||
| const std::string min_val_name = compress_feature_map.first + "." + kMinVal; | |||
| const std::string max_val_name = compress_feature_map.first + "." + kMaxVal; | |||
| const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name); | |||
| const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name); | |||
| float *fbs_min_val_ptr = reinterpret_cast<float *>(min_val_ptr->addr); | |||
| float *fbs_max_val_ptr = reinterpret_cast<float *>(max_val_ptr->addr); | |||
| auto fbs_compress_feature_map = schema::CreateCompressFeatureMap( | |||
| *(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr); | |||
| fbs_compress_feature_maps.push_back(fbs_compress_feature_map); | |||
| } | |||
| } | |||
| auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps); | |||
| schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get())); | |||
| rsp_fl_job_builder.add_retcode(static_cast<int>(retcode)); | |||
| rsp_fl_job_builder.add_reason(fbs_reason); | |||
| @@ -383,6 +443,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, | |||
| rsp_fl_job_builder.add_next_req_time(fbs_next_req_time); | |||
| rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan); | |||
| rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector); | |||
| rsp_fl_job_builder.add_download_compress_type(compressType); | |||
| rsp_fl_job_builder.add_compress_feature_map(fbs_compress_feature_maps_vector); | |||
| rsp_fl_job_builder.add_upload_compress_type(upload_compress_type); | |||
| rsp_fl_job_builder.add_upload_sparse_rate(ps::PSContext::instance()->upload_sparse_rate()); | |||
| auto rsp_fl_job = rsp_fl_job_builder.Finish(); | |||
| fbb->Finish(rsp_fl_job); | |||
| return; | |||
| @@ -25,6 +25,9 @@ | |||
| #include "fl/server/executor.h" | |||
| #include "fl/server/kernel/round/round_kernel.h" | |||
| #include "fl/server/kernel/round/round_kernel_factory.h" | |||
| #include "schema/fl_job_generated.h" | |||
| #include "schema/cipher_generated.h" | |||
| #include "fl/compression/encode_executor.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| @@ -56,7 +59,8 @@ class StartFLJobKernel : public RoundKernel { | |||
| // Distributed count service counts for startFLJob. | |||
| ResultCode CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req); | |||
| void StartFLJob(const std::shared_ptr<FBBuilder> &fbb); | |||
| void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta, | |||
| const schema::RequestFLJob *start_fl_job_req); | |||
| bool JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req); | |||
| @@ -65,7 +69,9 @@ class StartFLJobKernel : public RoundKernel { | |||
| // Build response for startFLJob round no matter success or failure. | |||
| void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const bool is_selected, const std::string &next_req_time, | |||
| std::map<std::string, AddressPtr> feature_maps = {}); | |||
| const std::map<std::string, AddressPtr> &feature_maps = {}, | |||
| const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS, | |||
| const std::map<std::string, AddressPtr> &compress_feature_maps = {}); | |||
| // The executor is for getting the initial model for startFLJob request. | |||
| Executor *executor_; | |||
| @@ -201,23 +201,27 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel | |||
| } | |||
| std::unordered_map<std::string, size_t> feature_map; | |||
| auto upload_feature_map = update_model_req->feature_map(); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail); | |||
| for (uint32_t i = 0; i < upload_feature_map->size(); i++) { | |||
| const auto &item = upload_feature_map->Get(i); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail); | |||
| std::string weight_full_name = item->weight_fullname()->str(); | |||
| size_t weight_size = item->data()->size() * sizeof(float); | |||
| feature_map[weight_full_name] = weight_size; | |||
| if (ps::PSContext::instance()->upload_compress_type() != kDiffSparseQuant) { | |||
| auto upload_feature_map = update_model_req->feature_map(); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail); | |||
| for (uint32_t i = 0; i < upload_feature_map->size(); i++) { | |||
| const auto &item = upload_feature_map->Get(i); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail); | |||
| std::string weight_full_name = item->weight_fullname()->str(); | |||
| size_t weight_size = item->data()->size() * sizeof(float); | |||
| feature_map[weight_full_name] = weight_size; | |||
| } | |||
| } | |||
| bool verifyFeatureMapIsSuccess; | |||
| if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType && update_model_req->sign() != 0) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->index_array(), ResultCode::kFail); | |||
| verifyFeatureMapIsSuccess = VerifySignDSFeatureMap(feature_map, update_model_req); | |||
| } else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) { | |||
| verifyFeatureMapIsSuccess = VerifyUploadCompressFeatureMap(update_model_req); | |||
| } else { | |||
| verifyFeatureMapIsSuccess = LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map); | |||
| } | |||
| @@ -280,6 +284,45 @@ bool UpdateModelKernel::VerifySignDSFeatureMap(const std::unordered_map<std::str | |||
| return true; | |||
| } | |||
| bool UpdateModelKernel::VerifyUploadCompressFeatureMap(const schema::RequestUpdateModel *update_model_req) { | |||
| auto &aggregation_feature_map_ = LocalMetaStore::GetInstance().aggregation_feature_map(); | |||
| auto upload_sparse_rate = update_model_req->upload_sparse_rate(); | |||
| if (upload_sparse_rate != ps::PSContext::instance()->upload_sparse_rate()) { | |||
| MS_LOG(WARNING) << "The upload_sparse_rate must be equal to the setting in context."; | |||
| return false; | |||
| } | |||
| auto fbs_name_vec = update_model_req->name_vec(); | |||
| if (fbs_name_vec == nullptr) { | |||
| MS_LOG(WARNING) << "The name_vec is null."; | |||
| return false; | |||
| } | |||
| if (fbs_name_vec->size() == 0) { | |||
| MS_LOG(WARNING) << "The size of name_vec must be larger than 0."; | |||
| return false; | |||
| } | |||
| if (fbs_name_vec->size() > aggregation_feature_map_.size()) { | |||
| MS_LOG(WARNING) << "The size of name_vec must be smaller than model in server."; | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < fbs_name_vec->size(); ++i) { | |||
| std::string name = fbs_name_vec->Get(i)->str(); | |||
| if (aggregation_feature_map_.count(name) == 0) { | |||
| MS_LOG(WARNING) << "The upload name: " << name << " is not in model in server."; | |||
| return false; | |||
| } | |||
| } | |||
| auto fbs_compress_feature_map = update_model_req->compress_feature_map(); | |||
| if (fbs_compress_feature_map == nullptr) { | |||
| MS_LOG(WARNING) << "The upload compress feature map is null."; | |||
| return false; | |||
| } | |||
| if (fbs_compress_feature_map->size() == 0) { | |||
| MS_LOG(WARNING) << "The upload compress feature map is empty."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, | |||
| const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail); | |||
| @@ -292,6 +335,8 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| std::map<std::string, UploadData> feature_map; | |||
| if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType) { | |||
| feature_map = ParseSignDSFeatureMap(update_model_req, data_size, &weight_map); | |||
| } else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) { | |||
| feature_map = ParseUploadCompressFeatureMap(update_model_req, data_size, &weight_map); | |||
| } else { | |||
| feature_map = ParseFeatureMap(update_model_req); | |||
| } | |||
| @@ -397,6 +442,89 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseSignDSFeatureMap( | |||
| return feature_map; | |||
| } | |||
| std::map<std::string, UploadData> UpdateModelKernel::ParseUploadCompressFeatureMap( | |||
| const schema::RequestUpdateModel *update_model_req, size_t data_size, | |||
| std::map<std::string, std::vector<float>> *weight_map) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, {}); | |||
| std::map<std::string, UploadData> feature_map; | |||
| schema::CompressType upload_compress_type = update_model_req->upload_compress_type(); | |||
| upload_compress_type = | |||
| mindspore::fl::compression::DecodeExecutor::GetInstance().GetCompressType(upload_compress_type); | |||
| MS_LOG(INFO) << "This schema upload compress type is: " << upload_compress_type; | |||
| if (upload_compress_type != schema::CompressType_NO_COMPRESS) { | |||
| MS_LOG(INFO) << "This upload compress type is DIFF_SPARSE_QUANT."; | |||
| feature_map = DecodeFeatureMap(weight_map, update_model_req, upload_compress_type, data_size); | |||
| return feature_map; | |||
| } | |||
| MS_LOG(INFO) << "This upload compress type is NO_COMPRESS."; | |||
| // Some clients upload origin weights. | |||
| auto fbs_feature_map = update_model_req->feature_map(); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, feature_map); | |||
| for (uint32_t i = 0; i < fbs_feature_map->size(); i++) { | |||
| std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str(); | |||
| float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data()); | |||
| size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float); | |||
| UploadData upload_data; | |||
| upload_data[kNewWeight].addr = weight_data; | |||
| upload_data[kNewWeight].size = weight_size; | |||
| feature_map[weight_full_name] = upload_data; | |||
| } | |||
| return feature_map; | |||
| } | |||
| std::map<std::string, UploadData> UpdateModelKernel::DecodeFeatureMap( | |||
| std::map<std::string, std::vector<float>> *weight_map, const schema::RequestUpdateModel *update_model_req, | |||
| schema::CompressType upload_compress_type, size_t data_size) { | |||
| std::map<std::string, UploadData> feature_map; | |||
| // Get and set decode hyper parameters. | |||
| auto seed = update_model_req->iteration(); | |||
| MS_LOG(INFO) << "The seed for compression is: " << seed; | |||
| auto upload_sparse_rate = update_model_req->upload_sparse_rate(); | |||
| MS_LOG(INFO) << "The upload_sparse_rate for compression is: " << upload_sparse_rate; | |||
| // Get name vector. | |||
| auto fbs_name_vec = update_model_req->name_vec(); | |||
| std::vector<std::string> name_vec; | |||
| for (size_t i = 0; i < fbs_name_vec->size(); ++i) { | |||
| name_vec.emplace_back(fbs_name_vec->Get(i)->str()); | |||
| } | |||
| // Parameter process for decode. | |||
| auto fbs_compress_feature_map = update_model_req->compress_feature_map(); | |||
| std::vector<mindspore::fl::compression::CompressFeatureMap> compress_feature_maps; | |||
| for (size_t i = 0; i < fbs_compress_feature_map->size(); ++i) { | |||
| mindspore::fl::compression::CompressFeatureMap compress_feature_map; | |||
| int8_t *compress_weight_data = const_cast<int8_t *>(fbs_compress_feature_map->Get(i)->compress_data()->data()); | |||
| size_t compress_weight_size = fbs_compress_feature_map->Get(i)->compress_data()->size(); | |||
| MS_LOG(INFO) << "The compress weight size: " << compress_weight_size; | |||
| for (size_t j = 0; j < compress_weight_size; ++j) { | |||
| compress_feature_map.compress_data.emplace_back(compress_weight_data[j]); | |||
| } | |||
| compress_feature_map.min_val = fbs_compress_feature_map->Get(i)->min_val(); | |||
| compress_feature_map.max_val = fbs_compress_feature_map->Get(i)->max_val(); | |||
| MS_LOG(INFO) << "Min value: " << compress_feature_map.min_val; | |||
| MS_LOG(INFO) << "Max value: " << compress_feature_map.max_val; | |||
| compress_feature_maps.emplace_back(compress_feature_map); | |||
| } | |||
| // Decode. | |||
| bool status = mindspore::fl::compression::DecodeExecutor::GetInstance().Decode( | |||
| weight_map, compress_feature_maps, upload_compress_type, upload_sparse_rate, seed, name_vec, data_size); | |||
| if (status) { | |||
| for (size_t i = 0; i < name_vec.size(); ++i) { | |||
| std::string weight_full_name = name_vec[i]; | |||
| size_t weight_size = (*weight_map)[weight_full_name].size() * sizeof(float); | |||
| UploadData upload_data; | |||
| upload_data[kNewWeight].addr = (*weight_map)[weight_full_name].data(); | |||
| upload_data[kNewWeight].size = weight_size; | |||
| feature_map[weight_full_name] = upload_data; | |||
| } | |||
| return feature_map; | |||
| } | |||
| MS_LOG(WARNING) << "Decode failed!"; | |||
| return feature_map; | |||
| } | |||
| ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) { | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason)) { | |||
| @@ -30,6 +30,9 @@ | |||
| #ifdef ENABLE_ARMOUR | |||
| #include "fl/armour/cipher/cipher_meta_storage.h" | |||
| #endif | |||
| #include "fl/compression/decode_executor.h" | |||
| #include "schema/fl_job_generated.h" | |||
| #include "schema/cipher_generated.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| @@ -64,8 +67,12 @@ class UpdateModelKernel : public RoundKernel { | |||
| std::map<std::string, UploadData> ParseSignDSFeatureMap(const schema::RequestUpdateModel *update_model_req, | |||
| size_t data_size, | |||
| std::map<std::string, std::vector<float>> *weight_map); | |||
| std::map<std::string, UploadData> ParseUploadCompressFeatureMap( | |||
| const schema::RequestUpdateModel *update_model_req, size_t data_size, | |||
| std::map<std::string, std::vector<float>> *weight_map); | |||
| bool VerifySignDSFeatureMap(const std::unordered_map<std::string, size_t> &model, | |||
| const schema::RequestUpdateModel *update_model_req); | |||
| bool VerifyUploadCompressFeatureMap(const schema::RequestUpdateModel *update_model_req); | |||
| ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestUpdateModel *update_model_req); | |||
| sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req); | |||
| @@ -78,6 +85,11 @@ class UpdateModelKernel : public RoundKernel { | |||
| // The time window of one iteration. | |||
| size_t iteration_time_window_{0}; | |||
| // Decode functions of compression. | |||
| std::map<std::string, UploadData> DecodeFeatureMap(std::map<std::string, std::vector<float>> *weight_map, | |||
| const schema::RequestUpdateModel *update_model_req, | |||
| schema::CompressType upload_compress_type, size_t data_size); | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace server | |||
| @@ -44,6 +44,11 @@ void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(array); | |||
| char_arrays_.push_back(std::move(*array)); | |||
| } | |||
| void MemoryRegister::StoreFloat32(std::unique_ptr<float> *param) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(param); | |||
| float_params_.push_back(std::move(*param)); | |||
| } | |||
| } // namespace server | |||
| } // namespace fl | |||
| } // namespace mindspore | |||
| @@ -24,6 +24,7 @@ | |||
| #include <utility> | |||
| #include <typeinfo> | |||
| #include "fl/server/common.h" | |||
| #include "fl/compression/encode_executor.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| @@ -70,6 +71,25 @@ class MemoryRegister { | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void RegisterParameter(const std::string &name, std::unique_ptr<T> *param, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| void *data = param->get(); | |||
| AddressPtr addressPtr = std::make_shared<Address>(); | |||
| addressPtr->addr = data; | |||
| addressPtr->size = size; | |||
| if (typeid(T) == typeid(float)) { | |||
| auto float_param = CastUniqueParamPtr<float, T>(param); | |||
| StoreFloat32(&float_param); | |||
| } else { | |||
| MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name(); | |||
| return; | |||
| } | |||
| RegisterAddressPtr(name, addressPtr); | |||
| return; | |||
| } | |||
| private: | |||
| std::map<std::string, AddressPtr> addresses_; | |||
| std::vector<std::unique_ptr<float[]>> float_arrays_; | |||
| @@ -86,6 +106,15 @@ class MemoryRegister { | |||
| std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) { | |||
| return std::unique_ptr<T[]>{reinterpret_cast<T *>(array->release())}; | |||
| } | |||
| std::vector<std::unique_ptr<float>> float_params_; | |||
| void StoreFloat32(std::unique_ptr<float> *array); | |||
| template <typename T, typename S> | |||
| std::unique_ptr<T> CastUniqueParamPtr(std::unique_ptr<S> *param) { | |||
| return std::unique_ptr<T>{reinterpret_cast<T *>(param->release())}; | |||
| } | |||
| }; | |||
| } // namespace server | |||
| } // namespace fl | |||
| @@ -19,6 +19,7 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "fl/server/executor.h" | |||
| #include "pipeline/jit/parse/parse.h" | |||
| #include "include/common/utils/python_adapter.h" | |||
| namespace mindspore { | |||
| @@ -33,6 +34,10 @@ void ModelStore::Initialize(uint32_t rank_id, uint32_t max_count) { | |||
| max_model_count_ = max_count; | |||
| initial_model_ = AssignNewModelMemory(); | |||
| iteration_to_model_[kInitIterationNum] = initial_model_; | |||
| std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel(); | |||
| for (const auto &item : mindspore::fl::compression::kCompressTypeMap) { | |||
| iteration_to_compress_model_[kInitIterationNum][item.first] = AssignNewCompressModelMemory(item.first, model); | |||
| } | |||
| model_size_ = ComputeModelSize(); | |||
| MS_LOG(INFO) << "Model store checkpoint dir is: " << ps::PSContext::instance()->checkpoint_dir(); | |||
| } | |||
| @@ -101,6 +106,24 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration | |||
| return model; | |||
| } | |||
| std::map<std::string, AddressPtr> ModelStore::GetCompressModelByIterNum(size_t iteration, | |||
| schema::CompressType compressType) { | |||
| std::unique_lock<std::mutex> lock(model_mtx_); | |||
| std::map<std::string, AddressPtr> compressModel = {}; | |||
| if (iteration_to_compress_model_.count(iteration) == 0) { | |||
| MS_LOG(ERROR) << "Compress Model for iteration " << iteration << " is not stored."; | |||
| return compressModel; | |||
| } | |||
| std::map<schema::CompressType, std::shared_ptr<MemoryRegister>> compress_model_map = | |||
| iteration_to_compress_model_[iteration]; | |||
| if (compress_model_map.count(compressType) == 0) { | |||
| MS_LOG(ERROR) << "Compress Model for compress type " << compressType << " is not stored."; | |||
| return compressModel; | |||
| } | |||
| compressModel = iteration_to_compress_model_[iteration][compressType]->addresses(); | |||
| return compressModel; | |||
| } | |||
| void ModelStore::Reset() { | |||
| std::unique_lock<std::mutex> lock(model_mtx_); | |||
| initial_model_ = iteration_to_model_.rbegin()->second; | |||
| @@ -114,6 +137,11 @@ const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_t | |||
| return iteration_to_model_; | |||
| } | |||
| const std::map<size_t, CompressTypeMap> &ModelStore::iteration_to_compress_model() { | |||
| std::unique_lock<std::mutex> lock(model_mtx_); | |||
| return iteration_to_compress_model_; | |||
| } | |||
| size_t ModelStore::model_size() const { return model_size_; } | |||
| std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() { | |||
| @@ -146,6 +174,86 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() { | |||
| return memory_register; | |||
| } | |||
| std::shared_ptr<MemoryRegister> ModelStore::AssignNewCompressModelMemory( | |||
| schema::CompressType compressType, const std::map<std::string, AddressPtr> &model) { | |||
| if (model.empty()) { | |||
| MS_LOG(EXCEPTION) << "Model feature map is empty."; | |||
| return nullptr; | |||
| } | |||
| std::map<string, std::vector<float>> feature_maps; | |||
| for (auto &feature_map : model) { | |||
| auto weight_fullname = feature_map.first; | |||
| auto weight_data = reinterpret_cast<float *>(feature_map.second->addr); | |||
| std::vector<float> weight_data_vector{weight_data, weight_data + feature_map.second->size / sizeof(float)}; | |||
| feature_maps[weight_fullname] = weight_data_vector; | |||
| } | |||
| std::map<std::string, mindspore::fl::compression::CompressWeight> compressWeights; | |||
| bool status = mindspore::fl::compression::CompressExecutor::GetInstance().construct_compress_weight( | |||
| &compressWeights, feature_maps, compressType); | |||
| if (!status) { | |||
| MS_LOG(ERROR) << "Encode failed!"; | |||
| return nullptr; | |||
| } | |||
| // Assign new memory for the compress model. | |||
| std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>(); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(memory_register, nullptr); | |||
| MS_LOG(INFO) << "Register compressWeight for compressType: " << schema::EnumNameCompressType(compressType); | |||
| for (const auto &compressWeight : compressWeights) { | |||
| if (compressType == schema::CompressType_QUANT) { | |||
| std::string compress_weight_name = compressWeight.first; | |||
| std::string min_val_name = compress_weight_name + "." + kMinVal; | |||
| std::string max_val_name = compress_weight_name + "." + kMaxVal; | |||
| size_t compress_weight_size = compressWeight.second.compress_data_len * sizeof(int8_t); | |||
| auto compress_weight_data = std::make_unique<char[]>(compress_weight_size); | |||
| auto src_data_size = compress_weight_size; | |||
| auto dst_data_size = compress_weight_size; | |||
| int ret = | |||
| memcpy_s(compress_weight_data.get(), dst_data_size, compressWeight.second.compress_data.data(), src_data_size); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return nullptr; | |||
| } | |||
| memory_register->RegisterArray(compress_weight_name, &compress_weight_data, compress_weight_size); | |||
| size_t float_size = 1; | |||
| auto min_val_ptr = std::make_unique<float>(compressWeight.second.min_val); | |||
| auto max_val_ptr = std::make_unique<float>(compressWeight.second.max_val); | |||
| memory_register->RegisterParameter(min_val_name, &min_val_ptr, float_size); | |||
| memory_register->RegisterParameter(max_val_name, &max_val_ptr, float_size); | |||
| } | |||
| } | |||
| return memory_register; | |||
| } | |||
| void ModelStore::StoreCompressModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) { | |||
| std::unique_lock<std::mutex> lock(model_mtx_); | |||
| if (iteration_to_compress_model_.count(iteration) != 0) { | |||
| MS_LOG(WARNING) << "Compress Model for iteration " << iteration << " is already stored"; | |||
| return; | |||
| } | |||
| if (new_model.empty()) { | |||
| MS_LOG(ERROR) << "Compress Model feature map is empty."; | |||
| return; | |||
| } | |||
| iteration_to_compress_model_[iteration] = {}; | |||
| if (iteration_to_compress_model_.size() >= max_model_count_) { | |||
| auto compress_model_map = iteration_to_compress_model_.begin()->second; | |||
| compress_model_map.clear(); | |||
| (void)iteration_to_compress_model_.erase(iteration_to_compress_model_.begin()); | |||
| } | |||
| for (const auto &item : mindspore::fl::compression::kCompressTypeMap) { | |||
| auto memory_register = AssignNewCompressModelMemory(item.first, new_model); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(memory_register); | |||
| iteration_to_compress_model_[iteration][item.first] = memory_register; | |||
| } | |||
| return; | |||
| } | |||
| size_t ModelStore::ComputeModelSize() { | |||
| std::unique_lock<std::mutex> lock(model_mtx_); | |||
| if (iteration_to_model_.empty()) { | |||
| @@ -179,13 +287,15 @@ void ModelStore::RelModelResponseCache(const void *data, size_t datalen, void *e | |||
| std::shared_ptr<std::vector<uint8_t>> ModelStore::GetModelResponseCache(const string &round_name, | |||
| size_t cur_iteration_num, | |||
| size_t model_iteration_num) { | |||
| size_t model_iteration_num, | |||
| const std::string &compress_type) { | |||
| std::unique_lock<std::mutex> lock(model_response_cache_lock_); | |||
| auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(), | |||
| [&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) { | |||
| return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num && | |||
| item.model_iteration_num == model_iteration_num; | |||
| }); | |||
| auto it = std::find_if( | |||
| model_response_cache_.begin(), model_response_cache_.end(), | |||
| [&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) { | |||
| return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num && | |||
| item.model_iteration_num == model_iteration_num && item.compress_type == compress_type; | |||
| }); | |||
| if (it == model_response_cache_.end()) { | |||
| return nullptr; | |||
| } | |||
| @@ -196,14 +306,16 @@ std::shared_ptr<std::vector<uint8_t>> ModelStore::GetModelResponseCache(const st | |||
| std::shared_ptr<std::vector<uint8_t>> ModelStore::StoreModelResponseCache(const string &round_name, | |||
| size_t cur_iteration_num, | |||
| size_t model_iteration_num, const void *data, | |||
| size_t datalen) { | |||
| size_t model_iteration_num, | |||
| const std::string &compress_type, | |||
| const void *data, size_t datalen) { | |||
| std::unique_lock<std::mutex> lock(model_response_cache_lock_); | |||
| auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(), | |||
| [&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) { | |||
| return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num && | |||
| item.model_iteration_num == model_iteration_num; | |||
| }); | |||
| auto it = std::find_if( | |||
| model_response_cache_.begin(), model_response_cache_.end(), | |||
| [&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) { | |||
| return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num && | |||
| item.model_iteration_num == model_iteration_num && item.compress_type == compress_type; | |||
| }); | |||
| if (it != model_response_cache_.end()) { | |||
| it->reference_count += 1; | |||
| total_add_reference_count += 1; | |||
| @@ -223,6 +335,7 @@ std::shared_ptr<std::vector<uint8_t>> ModelStore::StoreModelResponseCache(const | |||
| item.round_name = round_name; | |||
| item.cur_iteration_num = cur_iteration_num; | |||
| item.model_iteration_num = model_iteration_num; | |||
| item.compress_type = compress_type; | |||
| item.cache = cache; | |||
| item.reference_count = 1; | |||
| total_add_reference_count += 1; | |||
| @@ -25,6 +25,7 @@ | |||
| #include "fl/server/common.h" | |||
| #include "fl/server/memory_register.h" | |||
| #include "fl/server/executor.h" | |||
| #include "fl/compression/encode_executor.h" | |||
| #include "fl/server/local_meta_store.h" | |||
| namespace mindspore { | |||
| @@ -36,6 +37,9 @@ constexpr size_t kInitIterationNum = 0; | |||
| // The initial iteration number after ModelStore is reset. | |||
| constexpr size_t kResetInitialIterNum = 1; | |||
| // The compress type map. | |||
| using CompressTypeMap = std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>; | |||
| // Server framework use ModelStore to store and query models. | |||
| // ModelStore stores multiple models because worker could get models of the previous iterations. | |||
| class ModelStore { | |||
| @@ -64,15 +68,25 @@ class ModelStore { | |||
| // Returns the model size, which could be calculated at the initializing phase. | |||
| size_t model_size() const; | |||
| // Get compress model of the given iteration. | |||
| std::map<std::string, AddressPtr> GetCompressModelByIterNum(size_t iteration, schema::CompressType compressType); | |||
| const std::map<size_t, std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>> | |||
| &iteration_to_compress_model(); | |||
| void StoreCompressModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model); | |||
| static void RelModelResponseCache(const void *data, size_t datalen, void *extra); | |||
| std::shared_ptr<std::vector<uint8_t>> GetModelResponseCache(const std::string &round_name, size_t cur_iteration_num, | |||
| size_t model_iteration_num); | |||
| size_t model_iteration_num, | |||
| const std::string &compress_type); | |||
| std::shared_ptr<std::vector<uint8_t>> StoreModelResponseCache(const std::string &round_name, size_t cur_iteration_num, | |||
| size_t model_iteration_num, const void *data, | |||
| size_t model_iteration_num, | |||
| const std::string &compress_type, const void *data, | |||
| size_t datalen); | |||
| private: | |||
| ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {} | |||
| ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}), iteration_to_compress_model_({}) {} | |||
| ~ModelStore() = default; | |||
| ModelStore(const ModelStore &) = delete; | |||
| ModelStore &operator=(const ModelStore &) = delete; | |||
| @@ -83,6 +97,9 @@ class ModelStore { | |||
| // model_size_. | |||
| std::shared_ptr<MemoryRegister> AssignNewModelMemory(); | |||
| std::shared_ptr<MemoryRegister> AssignNewCompressModelMemory(schema::CompressType compressType, | |||
| const std::map<std::string, AddressPtr> &model); | |||
| // Calculate the model size. This method should be called after iteration_to_model_ is initialized. | |||
| size_t ComputeModelSize(); | |||
| @@ -95,12 +112,17 @@ class ModelStore { | |||
| // The number of all models stored is max_model_count_. | |||
| std::mutex model_mtx_; | |||
| std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_; | |||
| // iteration -> (compress type -> compress model) | |||
| std::map<size_t, std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>> iteration_to_compress_model_; | |||
| uint32_t rank_id_; | |||
| struct HttpResponseModelCache { | |||
| std::string round_name; // startFlJob, getModel | |||
| size_t cur_iteration_num = 0; | |||
| size_t model_iteration_num = 0; | |||
| std::string compress_type = kNoCompress; | |||
| size_t reference_count = 0; | |||
| std::shared_ptr<std::vector<uint8_t>> cache = nullptr; | |||
| }; | |||
| @@ -507,6 +507,12 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window, | |||
| "Set global iteration time window.") | |||
| .def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.") | |||
| .def("set_upload_compress_type", &PSContext::set_upload_compress_type, "Set upload compress type.") | |||
| .def("upload_compress_type", &PSContext::upload_compress_type, "Get upload compress type.") | |||
| .def("set_upload_sparse_rate", &PSContext::set_upload_sparse_rate, "Set upload sparse rate.") | |||
| .def("upload_sparse_rate", &PSContext::upload_sparse_rate, "Get upload sparse rate.") | |||
| .def("set_download_compress_type", &PSContext::set_download_compress_type, "Set download compress type.") | |||
| .def("download_compress_type", &PSContext::download_compress_type, "Get download compress type.") | |||
| .def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.") | |||
| .def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory."); | |||
| (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data."); | |||
| @@ -550,6 +550,19 @@ void PSContext::set_global_iteration_time_window(const uint64_t &global_iteratio | |||
| uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; } | |||
| void PSContext::set_upload_compress_type(const std::string &upload_compress_type) { | |||
| upload_compress_type_ = upload_compress_type; | |||
| } | |||
| std::string PSContext::upload_compress_type() const { return upload_compress_type_; } | |||
| void PSContext::set_upload_sparse_rate(float upload_sparse_rate) { upload_sparse_rate_ = upload_sparse_rate; } | |||
| float PSContext::upload_sparse_rate() const { return upload_sparse_rate_; } | |||
| void PSContext::set_download_compress_type(const std::string &download_compress_type) { | |||
| download_compress_type_ = download_compress_type; | |||
| } | |||
| std::string PSContext::download_compress_type() const { return download_compress_type_; } | |||
| std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; } | |||
| void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; } | |||
| @@ -40,6 +40,7 @@ constexpr char kPWEncryptType[] = "PW_ENCRYPT"; | |||
| constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT"; | |||
| constexpr char kNotEncryptType[] = "NOT_ENCRYPT"; | |||
| constexpr char kDSEncryptType[] = "SIGNDS"; | |||
| constexpr char kNoCompressType[] = "NO_COMPRESS"; | |||
| // Use binary data to represent federated learning server's context so that we can judge which round resets the | |||
| // iteration. From right to left, each bit stands for: | |||
| @@ -230,6 +231,15 @@ class PSContext { | |||
| void set_global_iteration_time_window(const uint64_t &global_iteration_time_window); | |||
| uint64_t global_iteration_time_window() const; | |||
| void set_upload_compress_type(const std::string &upload_compress_type); | |||
| std::string upload_compress_type() const; | |||
| void set_upload_sparse_rate(float upload_sparse_rate); | |||
| float upload_sparse_rate() const; | |||
| void set_download_compress_type(const std::string &download_compress_type); | |||
| std::string download_compress_type() const; | |||
| std::string checkpoint_dir() const; | |||
| void set_checkpoint_dir(const std::string &checkpoint_dir); | |||
| @@ -286,6 +296,9 @@ class PSContext { | |||
| server_password_(""), | |||
| http_url_prefix_(""), | |||
| global_iteration_time_window_(3600000), | |||
| upload_compress_type_(kNoCompressType), | |||
| upload_sparse_rate_(0.4f), | |||
| download_compress_type_(kNoCompressType), | |||
| checkpoint_dir_("") {} | |||
| bool ps_enabled_; | |||
| bool is_worker_; | |||
| @@ -419,6 +432,13 @@ class PSContext { | |||
| // The time window of startFLJob round in millisecond. | |||
| uint64_t global_iteration_time_window_; | |||
| // Hyper parameters for upload compression. | |||
| std::string upload_compress_type_; | |||
| float upload_sparse_rate_; | |||
| // Hyper parameters for download compression. | |||
| std::string download_compress_type_; | |||
| // directory of server checkpoint | |||
| std::string checkpoint_dir_; | |||
| }; | |||
| @@ -105,6 +105,16 @@ public class FLLiteClient { | |||
| batchSize = flPlan.miniBatch(); | |||
| String serverMod = flPlan.serverMode(); | |||
| localFLParameter.setServerMod(serverMod); | |||
| // Get and set hyper parameters for compression. | |||
| byte uploadCompressType = flJob.uploadCompressType(); | |||
| LOGGER.info(Common.addTag("[startFLJob] [compression] uploadCompressType: " + uploadCompressType)); | |||
| localFLParameter.setUploadCompressType(uploadCompressType); | |||
| float uploadSparseRate = flJob.uploadSparseRate(); | |||
| LOGGER.info(Common.addTag("[startFLJob] [compression] uploadSparseRate: " + uploadSparseRate)); | |||
| localFLParameter.setUploadSparseRatio(uploadSparseRate); | |||
| int seed = flJob.iteration(); | |||
| LOGGER.info(Common.addTag("[startFLJob] [compression] seed: " + seed)); | |||
| localFLParameter.setSeed(seed); | |||
| if (Common.checkFLName(flParameter.getFlName())) { | |||
| deprecatedSetBatchSize(batchSize); | |||
| } else { | |||
| @@ -446,7 +456,7 @@ public class FLLiteClient { | |||
| return status; | |||
| } | |||
| private Map<String, float[]> getFeatureMap() { | |||
| public Map<String, float[]> getFeatureMap() { | |||
| Map<String, float[]> featureMap = new HashMap<>(); | |||
| if (Common.checkFLName(flParameter.getFlName())) { | |||
| featureMap = deprecatedGetFeatureMap(); | |||
| @@ -530,8 +540,7 @@ public class FLLiteClient { | |||
| localFLParameter.getEncryptLevel().toString() + "> : " + curStatus)); | |||
| return curStatus; | |||
| case DP_ENCRYPT: | |||
| // get the feature map before train | |||
| oldFeatureMap = getFeatureMap(); | |||
| oldFeatureMap = localFLParameter.getOldFeatureMap(); | |||
| curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap); | |||
| retCode = ResponseCode.SUCCEED; | |||
| if (curStatus != FLClientStatus.SUCCESS) { | |||
| @@ -542,8 +551,7 @@ public class FLLiteClient { | |||
| LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!")); | |||
| return FLClientStatus.SUCCESS; | |||
| case SIGNDS: | |||
| // get the feature map before train | |||
| oldFeatureMap = getFeatureMap(); | |||
| oldFeatureMap = localFLParameter.getOldFeatureMap(); | |||
| curStatus = secureProtocol.setDSParameter(signK, signEps, signThrRatio, signGlobalLr, signDimOut, oldFeatureMap); | |||
| retCode = ResponseCode.SUCCEED; | |||
| if (curStatus != FLClientStatus.SUCCESS) { | |||
| @@ -18,7 +18,9 @@ package com.mindspore.flclient; | |||
| import static com.mindspore.flclient.LocalFLParameter.ALBERT; | |||
| import com.mindspore.flclient.compression.CompressMode; | |||
| import com.mindspore.flclient.model.RunType; | |||
| import mindspore.schema.CompressType; | |||
| import java.util.ArrayList; | |||
| import java.util.Arrays; | |||
| @@ -603,6 +605,16 @@ public class FLParameter { | |||
| this.batchSize = batchSize; | |||
| } | |||
| public byte[] getDownloadCompressTypes() { | |||
| byte[] downloadCompressTypes = new byte[CompressMode.COMPRESS_TYPE_MAP.size()]; | |||
| int index = 0; | |||
| for (byte downloadCompressType : CompressMode.COMPRESS_TYPE_MAP.keySet()) { | |||
| downloadCompressTypes[index] = downloadCompressType; | |||
| index += 1; | |||
| } | |||
| return downloadCompressTypes; | |||
| } | |||
| public int[][] getInputShape() { | |||
| return inputShape; | |||
| } | |||
| @@ -18,6 +18,7 @@ package com.mindspore.flclient; | |||
| import com.google.flatbuffers.FlatBufferBuilder; | |||
| import com.mindspore.flclient.compression.DecodeExecutor; | |||
| import com.mindspore.flclient.model.AlInferBert; | |||
| import com.mindspore.flclient.model.AlTrainBert; | |||
| import com.mindspore.flclient.model.Client; | |||
| @@ -27,11 +28,9 @@ import com.mindspore.flclient.model.SessionUtil; | |||
| import com.mindspore.flclient.model.Status; | |||
| import com.mindspore.flclient.model.TrainLenet; | |||
| import mindspore.schema.FeatureMap; | |||
| import mindspore.schema.RequestGetModel; | |||
| import mindspore.schema.ResponseCode; | |||
| import mindspore.schema.ResponseGetModel; | |||
| import mindspore.schema.*; | |||
| import java.util.List; | |||
| import java.util.ArrayList; | |||
| import java.util.Date; | |||
| import java.util.logging.Logger; | |||
| @@ -94,7 +93,8 @@ public class GetModel { | |||
| throw new IllegalArgumentException(); | |||
| } | |||
| RequestGetModelBuilder builder = new RequestGetModelBuilder(); | |||
| return builder.iteration(iteration).flName(name).time().build(); | |||
| return builder.iteration(iteration).flName(name).time() | |||
| .downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes()).build(); | |||
| } | |||
| private FLClientStatus deprecatedParseResponseAlbert(ResponseGetModel responseDataBuf) { | |||
| @@ -226,11 +226,29 @@ public class GetModel { | |||
| return status; | |||
| } | |||
| private List<FeatureMap> parseFeatureMapList(ResponseGetModel responseDataBuf) { | |||
| List<FeatureMap> featureMaps; | |||
| byte compressType = responseDataBuf.downloadCompressType(); | |||
| if (responseDataBuf.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) { | |||
| featureMaps = new ArrayList<>(); | |||
| for (int i = 0; i < responseDataBuf.featureMapLength(); i++) { | |||
| featureMaps.add(responseDataBuf.featureMap(i)); | |||
| } | |||
| } else { | |||
| List<mindspore.schema.CompressFeatureMap> compressFeatureMapList = new ArrayList<>(); | |||
| for (int i = 0; i < responseDataBuf.compressFeatureMapLength(); i++) { | |||
| compressFeatureMapList.add(responseDataBuf.compressFeatureMap(i)); | |||
| } | |||
| featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList); | |||
| } | |||
| return featureMaps; | |||
| } | |||
| private FLClientStatus parseResponseFeatures(ResponseGetModel responseDataBuf) { | |||
| FLClientStatus status; | |||
| Client client = ClientManager.getClient(flParameter.getFlName()); | |||
| int fmCount = responseDataBuf.featureMapLength(); | |||
| if (fmCount <= 0) { | |||
| List<FeatureMap> featureMapList = parseFeatureMapList(responseDataBuf); | |||
| if (featureMapList.size() <= 0) { | |||
| LOGGER.severe(Common.addTag("[getModel] the feature size get from server is zero")); | |||
| retCode = ResponseCode.SystemError; | |||
| return FLClientStatus.FAILED; | |||
| @@ -239,8 +257,8 @@ public class GetModel { | |||
| LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod())); | |||
| ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>(); | |||
| ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>(); | |||
| for (int i = 0; i < fmCount; i++) { | |||
| FeatureMap feature = responseDataBuf.featureMap(i); | |||
| for (int i = 0; i < featureMapList.size(); i++) { | |||
| FeatureMap feature = featureMapList.get(i); | |||
| if (feature == null) { | |||
| LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null")); | |||
| retCode = ResponseCode.SystemError; | |||
| @@ -289,8 +307,8 @@ public class GetModel { | |||
| } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { | |||
| LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod())); | |||
| ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>(); | |||
| for (int i = 0; i < fmCount; i++) { | |||
| FeatureMap feature = responseDataBuf.featureMap(i); | |||
| for (int i = 0; i < featureMapList.size(); i++) { | |||
| FeatureMap feature = featureMapList.get(i); | |||
| if (feature == null) { | |||
| LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null")); | |||
| retCode = ResponseCode.SystemError; | |||
| @@ -365,6 +383,7 @@ public class GetModel { | |||
| private int nameOffset = 0; | |||
| private int iteration = 0; | |||
| private int timeStampOffset = 0; | |||
| private int downloadCompressTypesOffset = 0; | |||
| public RequestGetModelBuilder() { | |||
| builder = new FlatBufferBuilder(); | |||
| @@ -392,11 +411,23 @@ public class GetModel { | |||
| return this; | |||
| } | |||
| private RequestGetModelBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) { | |||
| if (downloadCompressTypes == null || downloadCompressTypes.length == 0) { | |||
| LOGGER.severe(Common.addTag("[GetModel] the parameter of <downloadCompressTypes> is null or empty," + | |||
| " please check!")); | |||
| throw new IllegalArgumentException(); | |||
| } | |||
| this.downloadCompressTypesOffset = RequestGetModel.createDownloadCompressTypesVector(builder, | |||
| downloadCompressTypes); | |||
| return this; | |||
| } | |||
| private byte[] build() { | |||
| RequestGetModel.startRequestGetModel(builder); | |||
| RequestGetModel.addFlName(builder, nameOffset); | |||
| RequestGetModel.addIteration(builder, iteration); | |||
| RequestGetModel.addTimestamp(builder, timeStampOffset); | |||
| RequestGetModel.addDownloadCompressTypes(builder, downloadCompressTypesOffset); | |||
| int root = RequestGetModel.endRequestGetModel(builder); | |||
| builder.finish(root); | |||
| return builder.sizedByteArray(); | |||
| @@ -22,6 +22,7 @@ import org.bouncycastle.math.ec.rfc7748.X25519; | |||
| import java.util.ArrayList; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.logging.Logger; | |||
| /** | |||
| @@ -83,6 +84,10 @@ public class LocalFLParameter { | |||
| private MSConfig msConfig = new MSConfig(); | |||
| private boolean useSSL = true; | |||
| private float lr = 0.1f; | |||
| private Map<String, float[]> oldFeatureMap; | |||
| private byte uploadCompressType = 0; | |||
| private int seed = 0; | |||
| private float uploadSparseRatio = 0.08f; | |||
| private LocalFLParameter() { | |||
| @@ -250,4 +255,36 @@ public class LocalFLParameter { | |||
| public void setLr(float lr) { | |||
| this.lr = lr; | |||
| } | |||
| public Map<String, float[]> getOldFeatureMap() { | |||
| return oldFeatureMap; | |||
| } | |||
| public void setOldFeatureMap(Map<String, float[]> oldFeatureMap) { | |||
| this.oldFeatureMap = oldFeatureMap; | |||
| } | |||
| public byte getUploadCompressType() { | |||
| return uploadCompressType; | |||
| } | |||
| public void setUploadCompressType(byte uploadCompressType) { | |||
| this.uploadCompressType = uploadCompressType; | |||
| } | |||
| public int getSeed() { | |||
| return seed; | |||
| } | |||
| public void setSeed(int seed) { | |||
| this.seed = seed; | |||
| } | |||
| public float getUploadSparseRatio() { | |||
| return uploadSparseRatio; | |||
| } | |||
| public void setUploadSparseRatio(float uploadSparseRatio) { | |||
| this.uploadSparseRatio = uploadSparseRatio; | |||
| } | |||
| } | |||
| @@ -208,35 +208,34 @@ public class SecureProtocol { | |||
| * @param trainDataSize the size of the train data set. | |||
| * @return the model weights after adding masks. | |||
| */ | |||
| public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) { | |||
| public Map<String, List<Float>> pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, | |||
| float[]> trainedMap) { | |||
| Map<String, List<Float>> featureMaps = new HashMap<>(); | |||
| if (featureMask == null || featureMask.length == 0) { | |||
| LOGGER.severe("[Encrypt] feature mask is null, please check"); | |||
| return new int[0]; | |||
| return new HashMap<>(); | |||
| } | |||
| LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length)); | |||
| int featureSize = updateFeatureName.size(); | |||
| int[] featuresMap = new int[featureSize]; | |||
| int maskIndex = 0; | |||
| for (int i = 0; i < featureSize; i++) { | |||
| String key = updateFeatureName.get(i); | |||
| float[] data = trainedMap.get(key); | |||
| List<Float> featureMap = new ArrayList<>(); | |||
| LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length)); | |||
| for (int j = 0; j < data.length; j++) { | |||
| float rawData = data[j]; | |||
| if (maskIndex >= featureMask.length) { | |||
| LOGGER.severe("[Encrypt] the maskIndex is out of range for array featureMask, please check"); | |||
| return new int[0]; | |||
| return new HashMap<>(); | |||
| } | |||
| float maskData = rawData * trainDataSize + featureMask[maskIndex]; | |||
| maskIndex += 1; | |||
| data[j] = maskData; | |||
| featureMap.add(maskData); | |||
| } | |||
| int featureName = builder.createString(key); | |||
| int weight = FeatureMap.createDataVector(builder, data); | |||
| int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); | |||
| featuresMap[i] = featureMap; | |||
| featureMaps.put(key, featureMap); | |||
| } | |||
| return featuresMap; | |||
| return featureMaps; | |||
| } | |||
| /** | |||
| @@ -365,7 +364,9 @@ public class SecureProtocol { | |||
| * @param trainDataSize the size of the train data set. | |||
| * @return the model weights after adding masks. | |||
| */ | |||
| public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) { | |||
| public Map<String, List<Float>> dpMaskModel(FlatBufferBuilder builder, int trainDataSize, | |||
| Map<String, float[]> trainedMap) { | |||
| Map<String, List<Float>> featureMaps = new HashMap<>(); | |||
| // get feature map | |||
| Map<String, float[]> mapBeforeTrain = modelMap; | |||
| int featureSize = updateFeatureName.size(); | |||
| @@ -383,7 +384,7 @@ public class SecureProtocol { | |||
| float rawData = data[j]; | |||
| if (j >= dataBeforeTrain.length) { | |||
| LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check"); | |||
| return new int[0]; | |||
| return new HashMap<>(); | |||
| } | |||
| float rawDataBeforeTrain = dataBeforeTrain[j]; | |||
| float updateData = rawData - rawDataBeforeTrain; | |||
| @@ -393,23 +394,23 @@ public class SecureProtocol { | |||
| updateL2Norm = Math.sqrt(updateL2Norm); | |||
| if (updateL2Norm == 0) { | |||
| LOGGER.severe(Common.addTag("[Encrypt] updateL2Norm is 0, please check")); | |||
| return new int[0]; | |||
| return new HashMap<>(); | |||
| } | |||
| double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm); | |||
| // clip and add noise | |||
| int[] featuresMap = new int[featureSize]; | |||
| for (int i = 0; i < featureSize; i++) { | |||
| String key = updateFeatureName.get(i); | |||
| if (!trainedMap.containsKey(key)) { | |||
| LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!"); | |||
| return new int[0]; | |||
| return new HashMap<>(); | |||
| } | |||
| float[] data = trainedMap.get(key); | |||
| float[] data2 = new float[data.length]; | |||
| List<Float> featureMap = new ArrayList<>(); | |||
| if (!mapBeforeTrain.containsKey(key)) { | |||
| LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!"); | |||
| return new int[0]; | |||
| return new HashMap<>(); | |||
| } | |||
| float[] dataBeforeTrain = mapBeforeTrain.get(key); | |||
| @@ -419,7 +420,7 @@ public class SecureProtocol { | |||
| float rawData = data[j]; | |||
| if (j >= dataBeforeTrain.length) { | |||
| LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check"); | |||
| return new int[0]; | |||
| return new HashMap<>(); | |||
| } | |||
| float rawDataBeforeTrain = dataBeforeTrain[j]; | |||
| float updateData = rawData - rawDataBeforeTrain; | |||
| @@ -432,13 +433,11 @@ public class SecureProtocol { | |||
| updateData += gaussianNoise; | |||
| data2[j] = rawDataBeforeTrain + updateData; | |||
| data2[j] = data2[j] * trainDataSize; | |||
| featureMap.add(data2[j]); | |||
| } | |||
| int featureName = builder.createString(key); | |||
| int weight = FeatureMap.createDataVector(builder, data2); | |||
| int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); | |||
| featuresMap[i] = featureMap; | |||
| featureMaps.put(key, featureMap); | |||
| } | |||
| return featuresMap; | |||
| return featureMaps; | |||
| } | |||
| /** | |||
| @@ -18,6 +18,7 @@ package com.mindspore.flclient; | |||
| import com.google.flatbuffers.FlatBufferBuilder; | |||
| import com.mindspore.flclient.compression.DecodeExecutor; | |||
| import com.mindspore.flclient.model.AlInferBert; | |||
| import com.mindspore.flclient.model.AlTrainBert; | |||
| import com.mindspore.flclient.model.Client; | |||
| @@ -29,6 +30,7 @@ import com.mindspore.flclient.model.TrainLenet; | |||
| import com.mindspore.flclient.pki.PkiBean; | |||
| import com.mindspore.flclient.pki.PkiUtil; | |||
| import mindspore.schema.*; | |||
| import mindspore.schema.FLPlan; | |||
| import mindspore.schema.FeatureMap; | |||
| import mindspore.schema.RequestFLJob; | |||
| @@ -38,6 +40,7 @@ import mindspore.schema.ResponseFLJob; | |||
| import java.io.IOException; | |||
| import java.security.cert.Certificate; | |||
| import java.util.ArrayList; | |||
| import java.util.List; | |||
| import java.util.logging.Logger; | |||
| import static com.mindspore.flclient.LocalFLParameter.ALBERT; | |||
| @@ -119,6 +122,7 @@ public class StartFLJob { | |||
| .iteration(iteration) | |||
| .signData(pkiBean.getSignData()) | |||
| .certificateChain(pkiBean.getCertificates()) | |||
| .downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes()) | |||
| .build(); | |||
| } | |||
| return builder.flName(flParameter.getFlName()) | |||
| @@ -126,6 +130,7 @@ public class StartFLJob { | |||
| .id(localFLParameter.getFlID()) | |||
| .dataSize(dataSize) | |||
| .iteration(iteration) | |||
| .downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes()) | |||
| .build(); | |||
| } | |||
| @@ -151,8 +156,9 @@ public class StartFLJob { | |||
| ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>(); | |||
| ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>(); | |||
| featureSize = 0; | |||
| for (int i = 0; i < fmCount; i++) { | |||
| FeatureMap feature = flJob.featureMap(i); | |||
| List<FeatureMap> featureMapList = parseFeatureMapList(flJob); | |||
| for (int i = 0; i < featureMapList.size(); i++) { | |||
| FeatureMap feature = featureMapList.get(i); | |||
| if (feature == null) { | |||
| LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); | |||
| return FLClientStatus.FAILED; | |||
| @@ -233,12 +239,14 @@ public class StartFLJob { | |||
| private FLClientStatus deprecatedParseResponseLenet(ResponseFLJob flJob) { | |||
| FLClientStatus status; | |||
| int fmCount = flJob.featureMapLength(); | |||
| ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>(); | |||
| updateFeatureName.clear(); | |||
| featureSize = 0; | |||
| for (int i = 0; i < fmCount; i++) { | |||
| FeatureMap feature = flJob.featureMap(i); | |||
| List<FeatureMap> featureMapList = parseFeatureMapList(flJob); | |||
| ArrayList<FeatureMap> featureMaps = new ArrayList<>(); | |||
| for (int i = 0; i < featureMapList.size(); i++) { | |||
| FeatureMap feature = featureMapList.get(i); | |||
| if (feature == null) { | |||
| LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); | |||
| return FLClientStatus.FAILED; | |||
| @@ -267,6 +275,24 @@ public class StartFLJob { | |||
| return FLClientStatus.SUCCESS; | |||
| } | |||
| private List<FeatureMap> parseFeatureMapList(ResponseFLJob flJob) { | |||
| List<FeatureMap> featureMaps; | |||
| byte compressType = flJob.downloadCompressType(); | |||
| if (flJob.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) { | |||
| LOGGER.info(Common.addTag("[parseFeatureMapList] create no compress feature map.")); | |||
| featureMaps = new ArrayList<>(); | |||
| for (int i = 0; i < flJob.featureMapLength(); i++) { | |||
| featureMaps.add(flJob.featureMap(i)); | |||
| } | |||
| } else { | |||
| List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>(); | |||
| for (int i = 0; i < flJob.compressFeatureMapLength(); i++) { | |||
| compressFeatureMapList.add(flJob.compressFeatureMap(i)); | |||
| } | |||
| featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList); | |||
| } | |||
| return featureMaps; | |||
| } | |||
| private FLClientStatus hybridFeatures(ResponseFLJob flJob) { | |||
| FLClientStatus status; | |||
| @@ -275,8 +301,23 @@ public class StartFLJob { | |||
| ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>(); | |||
| ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>(); | |||
| featureSize = 0; | |||
| List<FeatureMap> featureMaps; | |||
| byte compressType = flJob.downloadCompressType(); | |||
| if (compressType == CompressType.NO_COMPRESS) { | |||
| featureMaps = new ArrayList<>(); | |||
| for (int i = 0; i < fmCount; i++) { | |||
| featureMaps.add(flJob.featureMap(i)); | |||
| } | |||
| } else { | |||
| List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>(); | |||
| for (int i = 0; i < flJob.compressFeatureMapLength(); i++) { | |||
| compressFeatureMapList.add(flJob.compressFeatureMap(i)); | |||
| } | |||
| featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList); | |||
| fmCount = featureMaps.size(); | |||
| } | |||
| for (int i = 0; i < fmCount; i++) { | |||
| FeatureMap feature = flJob.featureMap(i); | |||
| FeatureMap feature = featureMaps.get(i); | |||
| if (feature == null) { | |||
| LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); | |||
| retCode = ResponseCode.SystemError; | |||
| @@ -335,8 +376,23 @@ public class StartFLJob { | |||
| int fmCount = flJob.featureMapLength(); | |||
| ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>(); | |||
| featureSize = 0; | |||
| byte compressType = flJob.downloadCompressType(); | |||
| List<FeatureMap> parseFeatureMaps; | |||
| if (compressType == CompressType.NO_COMPRESS) { | |||
| parseFeatureMaps = new ArrayList<>(); | |||
| for (int i = 0; i < fmCount; i++) { | |||
| parseFeatureMaps.add(flJob.featureMap(i)); | |||
| } | |||
| } else { | |||
| List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>(); | |||
| for (int i = 0; i < flJob.compressFeatureMapLength(); i++) { | |||
| compressFeatureMapList.add(flJob.compressFeatureMap(i)); | |||
| } | |||
| parseFeatureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList); | |||
| fmCount = parseFeatureMaps.size(); | |||
| } | |||
| for (int i = 0; i < fmCount; i++) { | |||
| FeatureMap feature = flJob.featureMap(i); | |||
| FeatureMap feature = parseFeatureMaps.get(i); | |||
| if (feature == null) { | |||
| LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); | |||
| retCode = ResponseCode.SystemError; | |||
| @@ -437,8 +493,8 @@ public class StartFLJob { | |||
| switch (responseRetCode) { | |||
| case (ResponseCode.SUCCEED): | |||
| if (flJob.featureMapLength() <= 0) { | |||
| LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero")); | |||
| if (flJob.downloadCompressType() == CompressType.NO_COMPRESS && flJob.featureMapLength() <= 0) { | |||
| LOGGER.warning(Common.addTag("[startFLJob] the feature size get from server is zero")); | |||
| retCode = ResponseCode.SystemError; | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| @@ -484,6 +540,7 @@ public class StartFLJob { | |||
| private int equipCertOffset = 0; | |||
| private int equipCACertOffset = 0; | |||
| private int rootCertOffset = 0; | |||
| private int downloadCompressTypesOffset = 0; | |||
| public RequestStartFLJobBuilder() { | |||
| builder = new FlatBufferBuilder(); | |||
| @@ -598,6 +655,17 @@ public class StartFLJob { | |||
| return this; | |||
| } | |||
| private RequestStartFLJobBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) { | |||
| if (downloadCompressTypes == null || downloadCompressTypes.length == 0) { | |||
| LOGGER.severe(Common.addTag("[StartFLJob] the parameter of <downloadCompressTypes> is null or empty," + | |||
| " please check!")); | |||
| throw new IllegalArgumentException(); | |||
| } | |||
| this.downloadCompressTypesOffset = RequestFLJob.createDownloadCompressTypesVector(builder, | |||
| downloadCompressTypes); | |||
| return this; | |||
| } | |||
| /** | |||
| * build protobuffer | |||
| * | |||
| @@ -615,6 +683,7 @@ public class StartFLJob { | |||
| RequestFLJob.addEquipCaCert(builder, equipCACertOffset); | |||
| RequestFLJob.addEquipCert(builder, equipCertOffset); | |||
| RequestFLJob.addKeyAttestation(builder, keyAttestationOffset); | |||
| RequestFLJob.addDownloadCompressTypes(builder, downloadCompressTypesOffset); | |||
| int root = RequestFLJob.endRequestFLJob(builder); | |||
| builder.finish(root); | |||
| return builder.sizedByteArray(); | |||
| @@ -147,6 +147,10 @@ public class SyncFLJob { | |||
| LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration())); | |||
| updateTryTimePerIter(flLiteClient); | |||
| // Copy weights before training. | |||
| Map<String, float[]> oldFeatureMap = flLiteClient.getFeatureMap(); | |||
| localFLParameter.setOldFeatureMap(oldFeatureMap); | |||
| // create mask | |||
| curStatus = flLiteClient.getFeatureMask(); | |||
| if (curStatus == FLClientStatus.RESTART) { | |||
| @@ -26,11 +26,15 @@ import com.mindspore.flclient.model.SessionUtil; | |||
| import com.mindspore.flclient.model.Status; | |||
| import com.mindspore.flclient.model.TrainLenet; | |||
| import com.mindspore.lite.MSTensor; | |||
| import com.mindspore.flclient.compression.EncodeExecutor; | |||
| import com.mindspore.flclient.compression.CompressWeight; | |||
| import mindspore.schema.FeatureMap; | |||
| import mindspore.schema.CompressFeatureMap; | |||
| import mindspore.schema.RequestUpdateModel; | |||
| import mindspore.schema.ResponseCode; | |||
| import mindspore.schema.ResponseUpdateModel; | |||
| import static mindspore.schema.CompressType.NO_COMPRESS; | |||
| import java.util.ArrayList; | |||
| import java.util.Date; | |||
| @@ -208,6 +212,7 @@ public class UpdateModel { | |||
| private RequestUpdateModel requestUM; | |||
| private FlatBufferBuilder builder; | |||
| private int fmOffset = 0; | |||
| private int compFmOffset = 0; | |||
| private int nameOffset = 0; | |||
| private int idOffset = 0; | |||
| private int timestampOffset = 0; | |||
| @@ -215,8 +220,11 @@ public class UpdateModel { | |||
| private int sign = 0; | |||
| private int indexArrayOffset = 0; | |||
| private int iteration = 0; | |||
| private byte uploadCompressType = 0; | |||
| private float uploadSparseRate = 0.0f; | |||
| private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT; | |||
| private float uploadLossOffset = 0.0f; | |||
| private int nameVecOffset = 0; | |||
| private RequestUpdateModelBuilder(EncryptLevel encryptLevel) { | |||
| builder = new FlatBufferBuilder(); | |||
| @@ -294,34 +302,33 @@ public class UpdateModel { | |||
| } else { | |||
| trainedMap = getFeatureMap(); | |||
| } | |||
| Map<String, List<Float>> featureMaps = new HashMap<>(); | |||
| long startTime; | |||
| long endTime; | |||
| switch (encryptLevel) { | |||
| case PW_ENCRYPT: | |||
| int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap); | |||
| if (fmOffsetsPW == null || fmOffsetsPW.length == 0) { | |||
| LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " + | |||
| featureMaps = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap); | |||
| if (featureMaps == null || featureMaps.size() == 0) { | |||
| LOGGER.severe("[Encrypt] the return featureMaps from <secureProtocol.pwMaskModel> is " + | |||
| "null, please check"); | |||
| throw new IllegalArgumentException(); | |||
| } | |||
| this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW); | |||
| LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!")); | |||
| return this; | |||
| break; | |||
| case DP_ENCRYPT: | |||
| startTime = System.currentTimeMillis(); | |||
| int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap); | |||
| if (fmOffsetsDP == null || fmOffsetsDP.length == 0) { | |||
| LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " + | |||
| featureMaps = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap); | |||
| if (featureMaps == null || featureMaps.size() == 0) { | |||
| LOGGER.severe("[Encrypt] the return featureMaps from <secureProtocol.dpMaskModel> is " + | |||
| "null, please check"); | |||
| retCode = ResponseCode.RequestError; | |||
| status = FLClientStatus.FAILED; | |||
| throw new IllegalArgumentException(); | |||
| } | |||
| this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP); | |||
| LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!")); | |||
| endTime = System.currentTimeMillis(); | |||
| LOGGER.info(Common.addTag("[Encrypt] dp time is: " + (endTime - startTime) + "ms")); | |||
| return this; | |||
| LOGGER.info(Common.addTag("dp time is " + (endTime - startTime) + "ms")); | |||
| break; | |||
| case SIGNDS: | |||
| startTime = System.currentTimeMillis(); | |||
| // signds alg return indexArray, and package indexArray into flatbuffer. | |||
| @@ -352,31 +359,104 @@ public class UpdateModel { | |||
| this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsSignds); | |||
| LOGGER.info(Common.addTag("[Encrypt] SignDS mask model ok!")); | |||
| endTime = System.currentTimeMillis(); | |||
| LOGGER.info(Common.addTag("[Encrypt] signds time is: " + (endTime - startTime) + "ms")); | |||
| LOGGER.info(Common.addTag("signds time is " + (endTime - startTime) + "ms")); | |||
| return this; | |||
| case NOT_ENCRYPT: | |||
| default: | |||
| startTime = System.currentTimeMillis(); | |||
| int featureSize = updateFeatureName.size(); | |||
| int[] fmOffsets = new int[featureSize]; | |||
| for (int i = 0; i < featureSize; i++) { | |||
| String key = updateFeatureName.get(i); | |||
| float[] data = trainedMap.get(key); | |||
| LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " + | |||
| "size: " + data.length)); | |||
| for (int j = 0; j < data.length; j++) { | |||
| data[j] = data[j] * trainDataSize; | |||
| for (String name : updateFeatureName) { | |||
| float[] data = trainedMap.get(name); | |||
| List<Float> featureMap = new ArrayList<>(); | |||
| for (float datum : data) { | |||
| featureMap.add(datum * (float) trainDataSize); | |||
| } | |||
| int featureName = builder.createString(key); | |||
| int weight = FeatureMap.createDataVector(builder, data); | |||
| int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); | |||
| fmOffsets[i] = featureMap; | |||
| featureMaps.put(name, featureMap); | |||
| } | |||
| this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets); | |||
| endTime = System.currentTimeMillis(); | |||
| LOGGER.info(Common.addTag("[Encrypt] not encrypt time is: " + (endTime - startTime) + "ms")); | |||
| return this; | |||
| LOGGER.info(Common.addTag("not encrypt time is " + (endTime - startTime) + "ms")); | |||
| break; | |||
| } | |||
| byte uploadCompressType = localFLParameter.getUploadCompressType(); | |||
| if (uploadCompressType != NO_COMPRESS) { | |||
| startTime = System.currentTimeMillis(); | |||
| this.compFmOffset = buildCompFmOffset(featureMaps, trainDataSize); | |||
| this.uploadCompressType = localFLParameter.getUploadCompressType(); | |||
| this.uploadSparseRate = localFLParameter.getUploadSparseRatio(); | |||
| this.nameVecOffset = buildNameVecOffset(updateFeatureName); | |||
| endTime = System.currentTimeMillis(); | |||
| LOGGER.info(Common.addTag("compression time is " + (endTime - startTime) + "ms")); | |||
| return this; | |||
| } | |||
| this.fmOffset = buildFmOffset(featureMaps, updateFeatureName); | |||
| return this; | |||
| } | |||
/**
 * Compresses the trained feature maps and serializes them as a
 * CompressFeatureMap vector into the FlatBuffer under construction.
 *
 * @param featureMaps   feature name -> weight values (already scaled by trainDataSize by the caller)
 * @param trainDataSize local train data size, forwarded to the encoder for difference encoding
 * @return offset of the created compress_feature_map vector inside {@code builder}
 * @throws IllegalArgumentException if the encoder returns nothing usable; retCode/status are set first
 */
private int buildCompFmOffset(Map<String, List<Float>> featureMaps, int trainDataSize) {
    List<CompressWeight> compressWeights = EncodeExecutor.getInstance().encode(featureMaps, trainDataSize);
    if (compressWeights == null || compressWeights.size() == 0) {
        LOGGER.severe("[Compression] the return compressWeights from <encodeExecutor.encode> is " +
                "null, please check");
        retCode = ResponseCode.RequestError;
        status = FLClientStatus.FAILED;
        throw new IllegalArgumentException();
    }
    int compFeatureSize = compressWeights.size();
    int[] compFmOffsets = new int[compFeatureSize];
    int index = 0;
    for (CompressWeight compressWeight : compressWeights) {
        String weightFullname = compressWeight.getWeightFullname();
        List<Byte> compressData = compressWeight.getCompressData();
        float minVal = compressWeight.getMinValue();
        float maxVal = compressWeight.getMaxValue();
        // Unbox the quantized bytes into a primitive array for the FlatBuffer vector.
        byte[] data = new byte[compressData.size()];
        LOGGER.info(Common.addTag("[updateModel build compressWeight] feature name: "
                + weightFullname + ", feature size: " + data.length));
        for (int j = 0; j < data.length; j++) {
            data[j] = compressData.get(j);
        }
        // FlatBuffers requires child objects (string, data vector) to be written
        // before the table that references them.
        int featureName = builder.createString(weightFullname);
        int weight = CompressFeatureMap.createCompressDataVector(builder, data);
        int featureMap = CompressFeatureMap.createCompressFeatureMap(builder, featureName, weight,
                minVal, maxVal);
        LOGGER.info(Common.addTag("[Compression]" +
                " featureName: " + weightFullname +
                ", min_val: " + minVal +
                ", max_val: " + maxVal));
        compFmOffsets[index] = featureMap;
        index += 1;
    }
    return RequestUpdateModel.createCompressFeatureMapVector(builder, compFmOffsets);
}
| private int buildNameVecOffset(ArrayList<String> updateFeatureName) { | |||
| int featureSize = updateFeatureName.size(); | |||
| int[] nameVecOffsets = new int[featureSize]; | |||
| for (int i = 0; i < featureSize; i++) { | |||
| String key = updateFeatureName.get(i); | |||
| int featureName = builder.createString(key); | |||
| nameVecOffsets[i] = featureName; | |||
| } | |||
| return RequestUpdateModel.createNameVecVector(builder, nameVecOffsets); | |||
| } | |||
| private int buildFmOffset(Map<String, List<Float>> featureMaps, ArrayList<String> updateFeatureName) { | |||
| int featureSize = updateFeatureName.size(); | |||
| int[] fmOffsets = new int[featureSize]; | |||
| for (int i = 0; i < featureSize; i++) { | |||
| String key = updateFeatureName.get(i); | |||
| List<Float> featureMap = featureMaps.get(key); | |||
| float[] data = new float[featureMap.size()]; | |||
| LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " + | |||
| "size: " + data.length)); | |||
| for (int j = 0; j < data.length; j++) { | |||
| data[j] = featureMap.get(j); | |||
| } | |||
| int featureName = builder.createString(key); | |||
| int weight = FeatureMap.createDataVector(builder, data); | |||
| int featureMapOff = FeatureMap.createFeatureMap(builder, featureName, weight); | |||
| fmOffsets[i] = featureMapOff; | |||
| } | |||
| return RequestUpdateModel.createFeatureMapVector(builder, fmOffsets); | |||
| } | |||
| /** | |||
| @@ -417,6 +497,10 @@ public class UpdateModel { | |||
| RequestUpdateModel.addFlId(this.builder, idOffset); | |||
| RequestUpdateModel.addTimestamp(builder, this.timestampOffset); | |||
| RequestUpdateModel.addIteration(builder, this.iteration); | |||
| RequestUpdateModel.addCompressFeatureMap(builder, this.compFmOffset); | |||
| RequestUpdateModel.addUploadCompressType(builder, this.uploadCompressType); | |||
| RequestUpdateModel.addUploadSparseRate(builder, this.uploadSparseRate); | |||
| RequestUpdateModel.addNameVec(builder, this.nameVecOffset); | |||
| RequestUpdateModel.addFeatureMap(builder, this.fmOffset); | |||
| RequestUpdateModel.addSignature(builder, this.signDataOffset); | |||
| RequestUpdateModel.addUploadLoss(builder, this.uploadLossOffset); | |||
| @@ -0,0 +1,40 @@ | |||
| /* | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.compression; | |||
| import java.util.HashMap; | |||
| import java.util.Map; | |||
| import static mindspore.schema.CompressType.NO_COMPRESS; | |||
| import static mindspore.schema.CompressType.QUANT; | |||
| /** | |||
| * The compress mod. | |||
| * | |||
| * @since 2021-12-21 | |||
| */ | |||
/**
 * Registry of supported upload/download compression modes.
 * Maps each schema CompressType to the number of quantization bits it uses;
 * -1 marks a type that performs no quantization.
 *
 * NOTE(review): the map is public and mutable — presumably callers treat it as
 * read-only; consider Collections.unmodifiableMap if that should be enforced.
 *
 * @since 2021-12-21
 */
public class CompressMode {
    // compress type -> num bits (-1 = no quantization for NO_COMPRESS)
    public static final Map<Byte, Integer> COMPRESS_TYPE_MAP = new HashMap<>();
    static {
        COMPRESS_TYPE_MAP.put(NO_COMPRESS, -1);
        COMPRESS_TYPE_MAP.put(QUANT, 8);
    }
}
| @@ -0,0 +1,83 @@ | |||
| /* | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.compression; | |||
| import java.util.List; | |||
| /** | |||
| * Compress Weight Bean | |||
| * | |||
| * @since 2021-12-21 | |||
| */ | |||
| public class CompressWeight { | |||
| private String weightFullname; | |||
| private List<Byte> compressData; | |||
| private float minValue; | |||
| private float maxValue; | |||
| public CompressWeight() { | |||
| } | |||
| public CompressWeight(String weightFullname, List<Byte> compressData, float minValue, float maxValue) { | |||
| this.weightFullname = weightFullname; | |||
| this.compressData = compressData; | |||
| this.minValue = minValue; | |||
| this.maxValue = maxValue; | |||
| } | |||
| public String getWeightFullname() { | |||
| return weightFullname; | |||
| } | |||
| public void setWeightFullname(String weightFullname) { | |||
| this.weightFullname = weightFullname; | |||
| } | |||
| public List<Byte> getCompressData() { | |||
| return compressData; | |||
| } | |||
| public void setCompressData(List<Byte> compressData) { | |||
| this.compressData = compressData; | |||
| } | |||
| public float getMinValue() { | |||
| return minValue; | |||
| } | |||
| public void setMinValue(float minValue) { | |||
| this.minValue = minValue; | |||
| } | |||
| public float getMaxValue() { | |||
| return maxValue; | |||
| } | |||
| public void setMaxValue(float maxValue) { | |||
| this.maxValue = maxValue; | |||
| } | |||
| @Override | |||
| public String toString() { | |||
| return "CompressWeight{" + | |||
| "weightFullname='" + weightFullname + '\'' + | |||
| ", compressData=" + compressData + | |||
| ", minValue=" + minValue + | |||
| ", maxValue=" + maxValue + | |||
| '}'; | |||
| } | |||
| } | |||
| @@ -0,0 +1,115 @@ | |||
| /* | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.compression; | |||
| import com.google.flatbuffers.FlatBufferBuilder; | |||
| import com.mindspore.flclient.Common; | |||
| import com.mindspore.flclient.StartFLJob; | |||
| import mindspore.schema.CompressFeatureMap; | |||
| import mindspore.schema.FeatureMap; | |||
| import mindspore.schema.CompressType; | |||
| import java.nio.ByteBuffer; | |||
| import java.util.ArrayList; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.logging.Logger; | |||
| import static mindspore.schema.CompressType.QUANT; | |||
| /** | |||
| * Compress Executor | |||
| * | |||
| * @since 2021-12-21 | |||
| */ | |||
/**
 * Compress Executor
 *
 * Singleton that dequantizes CompressFeatureMap payloads received from the
 * server back into plain FeatureMap flatbuffers.
 *
 * @since 2021-12-21
 */
public class DecodeExecutor {
    private static final Logger LOGGER = Logger.getLogger(DecodeExecutor.class.toString());

    // Double-checked-locking singleton instance.
    private static volatile DecodeExecutor compressExecutor;

    private DecodeExecutor() {}

    /**
     * Returns the process-wide singleton, creating it lazily and thread-safely.
     */
    public static DecodeExecutor getInstance() {
        if (compressExecutor == null) {
            synchronized (DecodeExecutor.class) {
                if (compressExecutor == null) {
                    compressExecutor = new DecodeExecutor();
                }
            }
        }
        return compressExecutor;
    }

    /**
     * Dispatches decompression by compress type.
     *
     * @param compressType          schema CompressType byte from the server response
     * @param compressFeatureMapList compressed weights to decode
     * @return decoded feature maps; an empty list for unknown types or NO_COMPRESS
     *         (NO_COMPRESS maps to -1 bits in CompressMode and falls through to the
     *         final empty return)
     */
    public List<FeatureMap> deCompressWeight(byte compressType, List<CompressFeatureMap> compressFeatureMapList) {
        if (!CompressMode.COMPRESS_TYPE_MAP.containsKey(compressType)) {
            return new ArrayList<>();
        }
        LOGGER.info(Common.addTag("[deCompressWeight] create " + CompressType.name(compressType) + " feature map."));
        int num_bits = CompressMode.COMPRESS_TYPE_MAP.get(compressType);
        if (compressType == QUANT) {
            return deCompressQuantMinMax(compressFeatureMapList, num_bits);
        }
        return new ArrayList<>();
    }

    /**
     * Min-max dequantization: reverses byte b -> (b + 2^(bits-1)) * scale + min,
     * where scale = (max - min) / (2^bits - 1) + 1e-10. This must stay the exact
     * inverse of the server-side quantizer — NOTE(review): the 1e-10 epsilon and
     * formula mirror EncodeExecutor; keep both sides in lockstep.
     *
     * @param compressFeatureMapList quantized weights with per-tensor min/max
     * @param num_bits               quantization bit width (8 for QUANT)
     * @return freshly built FeatureMap flatbuffers, one per weight
     */
    private List<FeatureMap> deCompressQuantMinMax(List<CompressFeatureMap> compressFeatureMapList, int num_bits) {
        float temp1 = (float) (Math.pow(2, num_bits) - 1);
        float temp2 = (float) Math.pow(2, num_bits - 1);
        Map<String, float[]> deCompressFeatureMaps = new HashMap<>();
        int compressFeatureMapLength = compressFeatureMapList.size();
        for (int i = 0; i < compressFeatureMapLength; i++) {
            CompressFeatureMap compressFeatureMap = compressFeatureMapList.get(i);
            String weightName = compressFeatureMap.weightFullname();
            int compressDataLength = compressFeatureMap.compressDataLength();
            List<Byte> compressWeightList = new ArrayList<>();
            for (int j = 0; j < compressDataLength; j++) {
                compressWeightList.add(compressFeatureMap.compressData(j));
            }
            float minVal = compressFeatureMap.minVal();
            float maxVal = compressFeatureMap.maxVal();
            // Epsilon keeps scale non-zero when min == max (constant tensor).
            float scale_value = (float) ((maxVal - minVal) / temp1 + 1e-10);
            float[] params = new float[compressWeightList.size()];
            for (int j = 0; j < params.length; j++) {
                float val = (compressWeightList.get(j).intValue() + temp2) * scale_value + minVal;
                params[j] = val;
            }
            deCompressFeatureMaps.put(weightName, params);
        }
        // Re-serialize each decoded tensor as a standalone FeatureMap flatbuffer
        // so downstream code can consume it like an uncompressed response.
        List<FeatureMap> featureMaps = new ArrayList<>();
        for (String weightName : deCompressFeatureMaps.keySet()) {
            FlatBufferBuilder builder = new FlatBufferBuilder(0);
            int weightFullnameOffset = builder.createString(weightName);
            float[] data = deCompressFeatureMaps.get(weightName);
            int dataOffset = FeatureMap.createDataVector(builder, data);
            FeatureMap.startFeatureMap(builder);
            FeatureMap.addWeightFullname(builder, weightFullnameOffset);
            FeatureMap.addData(builder, dataOffset);
            int orc = FeatureMap.endFeatureMap(builder);
            builder.finish(orc);
            ByteBuffer buf = builder.dataBuffer();
            FeatureMap featureMap = FeatureMap.getRootAsFeatureMap(buf);
            featureMaps.add(featureMap);
        }
        return featureMaps;
    }
}
| @@ -0,0 +1,167 @@ | |||
| /* | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.compression; | |||
| import com.mindspore.flclient.LocalFLParameter; | |||
| import static mindspore.schema.CompressType.DIFF_SPARSE_QUANT; | |||
| import java.util.ArrayList; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.Queue; | |||
| import java.util.PriorityQueue; | |||
| /** | |||
| * Encode Executor | |||
| * | |||
| * @since 2021-12-21 | |||
| */ | |||
/**
 * Encode Executor
 *
 * Singleton that compresses the client's trained weights before upload:
 * difference encoding against the pre-training weights, deterministic sparse
 * masking, then min-max quantization to bytes.
 *
 * @since 2021-12-21
 */
public class EncodeExecutor {
    private final LocalFLParameter localFLParameter = LocalFLParameter.getInstance();

    // Double-checked-locking singleton instance.
    private static volatile EncodeExecutor encodeExecutor;

    private EncodeExecutor() {}

    /**
     * Returns the process-wide singleton, creating it lazily and thread-safely.
     */
    public static EncodeExecutor getInstance() {
        if (encodeExecutor == null) {
            synchronized (EncodeExecutor.class) {
                if (encodeExecutor == null) {
                    encodeExecutor = new EncodeExecutor();
                }
            }
        }
        return encodeExecutor;
    }

    // PRNG constants for the deterministic mask shuffle. NOTE(review): the names
    // look swapped relative to a classic Lehmer/MINSTD generator (2147483647 is
    // normally the modulus and 48271 the multiplier), but they intentionally
    // mirror the server-side C++ implementation (fl/compression/decode_executor.cc
    // uses identical names and values) — both sides must produce the same
    // sequence, so do not "fix" the naming here alone.
    private static final int multiplier = 2147483647;
    private static final double increment = 4294967294.0;
    private static final int modulo = 48271;

    /**
     * Builds a deterministic 0/1 mask of length paramNum with round(paramNum *
     * uploadSparseRatio) ones, shuffled by a seeded pseudo-random permutation.
     * The server reconstructs the identical mask from the same seed and ratio,
     * so the shuffle must stay bit-for-bit compatible.
     *
     * @param paramNum total number of parameters across all uploaded tensors
     * @return mask list; index i tells whether parameter i is retained
     */
    private List<Integer> constructMaskArray(int paramNum) {
        int seed = localFLParameter.getSeed();
        float uploadSparseRatio = localFLParameter.getUploadSparseRatio();

        List<Integer> maskArray = new ArrayList<>();

        // First retain_num entries 1, the rest 0, then shuffle.
        int retain_num = (int) ((float) (paramNum) * uploadSparseRatio);
        for (int i = 0; i < retain_num; ++i) {
            maskArray.add(1);
        }
        for (int i = retain_num; i < paramNum; ++i) {
            maskArray.add(0);
        }

        seed = ((seed + multiplier) * modulo) % multiplier;
        for (int i = 0; i < paramNum; ++i) {
            // generate random number in (0, 1)
            double rand = (double) (seed) / increment + 0.5;
            // update seed
            seed = (seed * modulo) % multiplier;
            // Fisher–Yates-style swap of position i with a random later position.
            int j = (int) (rand * (double) (paramNum - i)) + i;
            int temp = maskArray.get(i);
            maskArray.set(i, maskArray.get(j));
            maskArray.set(j, temp);
        }
        return maskArray;
    }

    /**
     * DIFF_SPARSE_QUANT pipeline: diff vs pre-training weights, sparse masking,
     * then numBits min-max quantization.
     *
     * NOTE(review): the mask is consumed in HashMap keySet() iteration order
     * (diffFeatureMaps), which is not the updateFeatureName order the request's
     * name_vec carries — verify the server pairs mask indices with tensors the
     * same way, since HashMap order is unspecified across implementations.
     *
     * @param featureMaps   feature name -> trained values (scaled by trainDataSize)
     * @param numBits       quantization bit width
     * @param trainDataSize local train data size used to rescale the old weights
     * @return one CompressWeight per tensor
     */
    public List<CompressWeight> enDiffSparseQuant(Map<String, List<Float>> featureMaps, int numBits,
                                                  int trainDataSize) {
        List<CompressWeight> compressWeights = new ArrayList<>();

        // difference encode
        Map<String, float[]> oldFeatureMap = localFLParameter.getOldFeatureMap();
        Map<String, List<Float>> diffFeatureMaps = new HashMap<>();
        for (String featureMapName : featureMaps.keySet()) {
            List<Float> diffs = new ArrayList<>();
            List<Float> featureMap = featureMaps.get(featureMapName);
            float[] dataBeforeTrain = oldFeatureMap.get(featureMapName);
            int length = dataBeforeTrain.length;
            for (int i = 0; i < length; ++i) {
                // Old weights are scaled by trainDataSize to match the uploaded scale.
                float diff = featureMap.get(i) - dataBeforeTrain[i] * (float) trainDataSize;
                diffs.add(diff);
            }
            diffFeatureMaps.put(featureMapName, diffs);
        }

        // sparse encode
        int paramNum = 0;
        for (String featureMapName : diffFeatureMaps.keySet()) {
            int weightSize = diffFeatureMaps.get(featureMapName).size();
            paramNum += weightSize;
        }
        List<Integer> maskArray = constructMaskArray(paramNum);
        Map<String, List<Float>> sparseFeatureMaps = new HashMap<>();
        int index = 0;
        for (String featureMapName : diffFeatureMaps.keySet()) {
            List<Float> sparseFeatureMap = new ArrayList<>();
            List<Float> Weight = diffFeatureMaps.get(featureMapName);
            for (Float dataValue : Weight) {
                if (maskArray.get(index) == 1) {
                    sparseFeatureMap.add(dataValue);
                }
                index += 1;
            }
            sparseFeatureMaps.put(featureMapName, sparseFeatureMap);
        }

        // quant encode
        float temp1 = (float) (1 << numBits) - 1.0f;
        float temp2 = (float) (1 << (numBits - 1));
        for (String featureMapName : sparseFeatureMaps.keySet()) {
            CompressWeight compressWeight = new CompressWeight();
            compressWeight.setWeightFullname(featureMapName);
            List<Float> sparseFeatureMap = sparseFeatureMaps.get(featureMapName);

            // get min and max value
            // NOTE(review): if a tensor ends up with zero retained values, minVal
            // stays Float.MAX_VALUE and maxVal its negation — confirm the server
            // tolerates such an empty CompressWeight.
            Float minVal = Float.MAX_VALUE;
            float maxVal = -minVal;
            for (Float value : sparseFeatureMap) {
                if (value < minVal) {
                    minVal = value;
                }
                if (value > maxVal) {
                    maxVal = value;
                }
            }
            compressWeight.setMinValue(minVal);
            compressWeight.setMaxValue(maxVal);
            // Epsilon keeps scale non-zero when min == max; must match the decoder.
            float scale_value = (maxVal - minVal) / temp1 + 1e-10f;
            List<Byte> compressData = new ArrayList<>();
            for (Float aFloat : sparseFeatureMap) {
                compressData.add((byte) (Math.round((aFloat - minVal) / scale_value - temp2)));
            }
            compressWeight.setCompressData(compressData);
            compressWeights.add(compressWeight);
        }
        return compressWeights;
    }

    /**
     * Dispatches encoding by the locally configured upload compress type.
     *
     * @throws IllegalArgumentException for any type other than DIFF_SPARSE_QUANT
     *         (callers gate on NO_COMPRESS before invoking this)
     */
    public List<CompressWeight> encode(Map<String, List<Float>> featureMaps, int trainDataSize) {
        byte uploadCompressType = localFLParameter.getUploadCompressType();
        if (uploadCompressType == DIFF_SPARSE_QUANT) {
            return enDiffSparseQuant(featureMaps, 8, trainDataSize);
        }
        throw new IllegalArgumentException();
    }
}
| @@ -15,8 +15,9 @@ | |||
| """Context for parameter server training mode""" | |||
| import os | |||
| from mindspore._checkparam import Validator | |||
| from mindspore._checkparam import Validator, Rel | |||
| from mindspore._c_expression import PSContext | |||
| from mindspore import log as logger | |||
| _ps_context = None | |||
| @@ -79,6 +80,9 @@ _set_ps_context_func_map = { | |||
| "sign_global_lr": ps_context().set_sign_global_lr, | |||
| "sign_dim_out": ps_context().set_sign_dim_out, | |||
| "checkpoint_dir": ps_context().set_checkpoint_dir, | |||
| "upload_compress_type": ps_context().set_upload_compress_type, | |||
| "upload_sparse_rate": ps_context().set_upload_sparse_rate, | |||
| "download_compress_type": ps_context().set_download_compress_type, | |||
| } | |||
| _get_ps_context_func_map = { | |||
| @@ -126,7 +130,10 @@ _get_ps_context_func_map = { | |||
| "sign_thr_ratio": ps_context().sign_thr_ratio, | |||
| "sign_global_lr": ps_context().sign_global_lr, | |||
| "sign_dim_out": ps_context().sign_dim_out, | |||
| "checkpoint_dir": ps_context().checkpoint_dir | |||
| "checkpoint_dir": ps_context().checkpoint_dir, | |||
| "upload_compress_type": ps_context().upload_compress_type, | |||
| "upload_sparse_rate": ps_context().upload_sparse_rate, | |||
| "download_compress_type": ps_context().download_compress_type, | |||
| } | |||
| _check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port", | |||
| @@ -140,6 +147,15 @@ _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"] | |||
| _check_port_keys = ["scheduler_port", "fl_server_port"] | |||
| _check_string_keys = { | |||
| "upload_compress_type": ["NO_COMPRESS", "DIFF_SPARSE_QUANT"], | |||
| "download_compress_type": ["NO_COMPRESS", "QUANT"], | |||
| } | |||
| _check_float_range_keys = { | |||
| "upload_sparse_rate": {"lower_limit": 0.0, "upper_limit": 1.0, "rel": Rel.INC_RIGHT}, | |||
| } | |||
| def _get_ps_mode_rank(): | |||
| ps_rank = ps_context().ps_rank_id() | |||
| if ps_rank == -1: | |||
| @@ -183,6 +199,7 @@ def _set_ps_context(**kwargs): | |||
| Examples: | |||
| >>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456') | |||
| """ | |||
| kwargs = _check_conflict_value(kwargs) | |||
| for key, value in kwargs.items(): | |||
| if key not in _set_ps_context_func_map: | |||
| raise ValueError("Set PS context keyword %s is not recognized!" % key) | |||
| @@ -287,6 +304,31 @@ def _check_value(key, value): | |||
| if key in _check_positive_float_keys: | |||
| Validator.check_positive_float(value, key) | |||
| if key in _check_string_keys: | |||
| try: | |||
| string_keys = _check_string_keys[key] | |||
| Validator.check_string(value, string_keys) | |||
| except KeyError: | |||
| pass | |||
| if key in _check_float_range_keys: | |||
| try: | |||
| range_keys = _check_float_range_keys[key] | |||
| Validator.check_float_range(value, **range_keys) | |||
| except KeyError: | |||
| pass | |||
| if key in _check_port_keys: | |||
| if value < 1 or value > 65535: | |||
| raise ValueError("The range of %s must be 1 to 65535, but got %d." % (key, value)) | |||
| def _check_conflict_value(kwargs): | |||
| if "upload_compress_type" in kwargs and " encrypt_type" in kwargs: | |||
| if kwargs["upload_compress_type"] != "NO_COMPRESS" and kwargs["encrypt_type"] in ("SIGNDS", "PW_ENCRYPT"): | |||
| logger.warning("The '{}' and '{}' are conflicted, and in '{}' mode the" | |||
| " 'upload_compress_type' will be 'NO_COMPRESS'".format(kwargs["encrypt_type"], | |||
| kwargs["upload_compress_type"], | |||
| kwargs["encrypt_type"])) | |||
| kwargs["upload_compress_type"] = "NO_COMPRESS" | |||
| return kwargs | |||
| @@ -47,6 +47,16 @@ table FeatureMap{ | |||
| weight_fullname:string; | |||
| data:[float]; | |||
| } | |||
// Compression algorithms shared by client and server:
// NO_COMPRESS = raw floats, DIFF_SPARSE_QUANT = upload path (diff + sparse + 8-bit quant),
// QUANT = download path (8-bit min-max quantization).
enum CompressType:byte {NO_COMPRESS = 0, DIFF_SPARSE_QUANT = 1, QUANT = 2}

// One quantized weight tensor: byte payload plus the per-tensor min/max
// needed to dequantize (value = (byte + 2^(bits-1)) * scale + min_val).
table CompressFeatureMap{
weight_fullname:string;
compress_data:[int8];
min_val:float;
max_val:float;
}
| table RequestFLJob{ | |||
| fl_name:string; | |||
| fl_id:string; | |||
| @@ -58,6 +68,7 @@ table RequestFLJob{ | |||
| equip_cert:string; | |||
| equip_ca_cert:string; | |||
| root_cert:string; | |||
| download_compress_types:[CompressType]; | |||
| } | |||
| table ResponseFLJob { | |||
| retcode:int; | |||
| @@ -68,6 +79,10 @@ table ResponseFLJob { | |||
| fl_plan_config:FLPlan; | |||
| feature_map:[FeatureMap]; | |||
| timestamp:string; | |||
| upload_compress_type:CompressType; | |||
| upload_sparse_rate:float; | |||
| download_compress_type:CompressType; | |||
| compress_feature_map:[CompressFeatureMap]; | |||
| } | |||
| table FLPlan { | |||
| @@ -94,6 +109,10 @@ table RequestUpdateModel{ | |||
| upload_loss:float; | |||
| sign:int; | |||
| index_array:[int]; | |||
| compress_feature_map:[CompressFeatureMap]; | |||
| upload_compress_type:CompressType; | |||
| upload_sparse_rate:float; | |||
| name_vec:[string]; | |||
| } | |||
| table ResponseUpdateModel{ | |||
| @@ -132,6 +151,7 @@ table RequestGetModel{ | |||
| fl_name:string; | |||
| iteration:int; | |||
| timestamp:string; | |||
| download_compress_types:[CompressType]; | |||
| } | |||
| table ResponseGetModel{ | |||
| retcode:int; | |||
| @@ -139,6 +159,8 @@ table ResponseGetModel{ | |||
| iteration:int; | |||
| feature_map:[FeatureMap]; | |||
| timestamp:string; | |||
| download_compress_type:CompressType; | |||
| compress_feature_map:[CompressFeatureMap]; | |||
| } | |||
| table RequestAsyncGetModel{ | |||