/** * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "fl/compression/encode_executor.h" #include #include #include #include #include #include #include #include #include #include #include "fl/server/common.h" namespace mindspore { namespace fl { namespace compression { bool CompressExecutor::EnableCompressWeight(const schema::CompressType compressType) { return kCompressTypeMap.count(compressType) > 0; } bool CompressExecutor::construct_compress_weight(std::map *compressWeights, std::map> feature_maps, const schema::CompressType compressType) { if (compressType == schema::CompressType_QUANT) { return quant_min_max(compressWeights, feature_maps, kCompressTypeMap.at(compressType)); } return false; } bool CompressExecutor::quant_min_max(std::map *compressWeights, std::map> feature_maps, size_t num_bits) { auto temp1 = static_cast(1 << num_bits) - 1.0f; auto temp2 = static_cast(1 << (num_bits - 1)); for (const auto &feature_map : feature_maps) { std::string weight_name = feature_map.first; float min_value = 1e10f; float max_value = -min_value; for (const auto &feature : feature_map.second) { if (feature > max_value) { max_value = feature; } if (feature < min_value) { min_value = feature; } } float scale_value = (max_value - min_value) / temp1 + 1e-10f; size_t size = feature_map.second.size(); if (size == 0) { MS_LOG(WARNING) << "The size of parameters is zero."; return false; } CompressWeight compressWeight; for (size_t i = 0; i < size; ++i) { auto round_data = round((feature_map.second[i] - min_value) / scale_value - temp2); // bit pack can be implemented here in the future auto int8_data = int8_t(round_data); compressWeight.compress_data.emplace_back(int8_data); } compressWeight.min_val = min_value; compressWeight.max_val = max_value; compressWeight.compress_data_len = size; (*compressWeights)[weight_name] = compressWeight; } return true; } schema::CompressType CompressExecutor::GetCompressType(const flatbuffers::Vector *download_compress_types) { schema::CompressType compressType = schema::CompressType_NO_COMPRESS; if (download_compress_types == nullptr) { MS_LOG(DEBUG) << "The client does not support current download compress type."; } else { for (size_t i = 0; i < download_compress_types->size(); ++i) { auto download_compress_type = download_compress_types->Get(i); if (download_compress_type == schema::CompressType_QUANT) { compressType = schema::CompressType_QUANT; break; } } } return compressType; } } // namespace compression } // namespace fl } // namespace mindspore