You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

encode_executor.cc 3.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. /**
  2. * Copyright 2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "fl/compression/encode_executor.h"
#include <arpa/inet.h>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <map>
#include <regex>
#include <utility>
#include <vector>
#include "fl/server/common.h"
  28. namespace mindspore {
  29. namespace fl {
  30. namespace compression {
  31. bool CompressExecutor::EnableCompressWeight(const schema::CompressType compressType) {
  32. return kCompressTypeMap.count(compressType) > 0;
  33. }
  34. bool CompressExecutor::construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
  35. std::map<std::string, std::vector<float>> feature_maps,
  36. const schema::CompressType compressType) {
  37. if (compressType == schema::CompressType_QUANT) {
  38. return quant_min_max(compressWeights, feature_maps, kCompressTypeMap.at(compressType));
  39. }
  40. return false;
  41. }
  42. bool CompressExecutor::quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
  43. std::map<std::string, std::vector<float>> feature_maps, size_t num_bits) {
  44. auto temp1 = static_cast<float>(1 << num_bits) - 1.0f;
  45. auto temp2 = static_cast<float>(1 << (num_bits - 1));
  46. for (const auto &feature_map : feature_maps) {
  47. std::string weight_name = feature_map.first;
  48. float min_value = 1e10f;
  49. float max_value = -min_value;
  50. for (const auto &feature : feature_map.second) {
  51. if (feature > max_value) {
  52. max_value = feature;
  53. }
  54. if (feature < min_value) {
  55. min_value = feature;
  56. }
  57. }
  58. float scale_value = (max_value - min_value) / temp1 + 1e-10f;
  59. size_t size = feature_map.second.size();
  60. if (size == 0) {
  61. MS_LOG(WARNING) << "The size of parameters is zero.";
  62. return false;
  63. }
  64. CompressWeight compressWeight;
  65. for (size_t i = 0; i < size; ++i) {
  66. auto round_data = round((feature_map.second[i] - min_value) / scale_value - temp2);
  67. // bit pack can be implemented here in the future
  68. auto int8_data = int8_t(round_data);
  69. compressWeight.compress_data.emplace_back(int8_data);
  70. }
  71. compressWeight.min_val = min_value;
  72. compressWeight.max_val = max_value;
  73. compressWeight.compress_data_len = size;
  74. (*compressWeights)[weight_name] = compressWeight;
  75. }
  76. return true;
  77. }
  78. schema::CompressType CompressExecutor::GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types) {
  79. schema::CompressType compressType = schema::CompressType_NO_COMPRESS;
  80. if (download_compress_types == nullptr) {
  81. MS_LOG(DEBUG) << "The client does not support current download compress type.";
  82. } else {
  83. for (size_t i = 0; i < download_compress_types->size(); ++i) {
  84. auto download_compress_type = download_compress_types->Get(i);
  85. if (download_compress_type == schema::CompressType_QUANT) {
  86. compressType = schema::CompressType_QUANT;
  87. break;
  88. }
  89. }
  90. }
  91. return compressType;
  92. }
  93. } // namespace compression
  94. } // namespace fl
  95. } // namespace mindspore