You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

huffman_encode.cc 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/converter/quantizer/huffman_encode.h"
  17. #include <utility>
  18. #include <iostream>
  19. #include "src/dequant.h"
  20. namespace mindspore {
  21. namespace lite {
  22. STATUS HuffmanEncode::DoHuffmanEncode(const ParamValueLitePtr &weight, const std::shared_ptr<PrimitiveC> &primitive_c,
  23. void *quant_datas, const size_t &bit_num) {
  24. if (quant_datas == nullptr) {
  25. MS_LOG(ERROR) << "quant data is nullptr";
  26. return RET_ERROR;
  27. }
  28. auto *raw_datas = static_cast<int8_t *>(quant_datas);
  29. size_t elem_count = weight->tensor_shape_size();
  30. size_t packed_size = elem_count * bit_num;
  31. HuffmanPriorityQueue pq;
  32. auto status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
  33. if (status != RET_OK) {
  34. MS_LOG(ERROR) << "GetHuffmanPriorityQueue failed";
  35. return status;
  36. }
  37. status = BuildHuffmanTree(&pq);
  38. if (status != RET_OK) {
  39. MS_LOG(ERROR) << "BuildHuffmanTree failed";
  40. return status;
  41. }
  42. status = DoHuffmanCompress(raw_datas, elem_count);
  43. if (status != RET_OK) {
  44. MS_LOG(ERROR) << "DoHuffmanCompress failed";
  45. return status;
  46. }
  47. size_t ch_size = huffman_encoded_str_.length();
  48. if (ch_size < packed_size) {
  49. auto encode_data = new (std::nothrow) char[ch_size];
  50. if (encode_data == nullptr) {
  51. MS_LOG(ERROR) << "new char[] failed.";
  52. return RET_MEMORY_FAILED;
  53. }
  54. if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) {
  55. MS_LOG(ERROR) << "memcpy_s failed.";
  56. delete[] encode_data;
  57. return RET_MEMORY_FAILED;
  58. }
  59. weight->SetTensorData(encode_data, ch_size);
  60. primitive_c->set_enable_huffman_code(true);
  61. }
  62. huffman_encoded_str_.clear();
  63. huffman_table_.clear();
  64. return RET_SUCCESS;
  65. }
  66. STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
  67. MS_ASSERT(data != nullptr);
  68. std::map<int8_t, size_t> freq_map;
  69. for (size_t i = 0; i < data_size; i++) {
  70. freq_map[data[i]]++;
  71. }
  72. for (auto &kv : freq_map) {
  73. if (kv.second <= 0) {
  74. continue;
  75. }
  76. auto node = new (std::nothrow) HuffmanNode();
  77. if (node == nullptr) {
  78. MS_LOG(ERROR) << "new HuffmanNode failed.";
  79. return RET_MEMORY_FAILED;
  80. }
  81. this->huffman_nodes_.push_back(node);
  82. node->key = kv.first;
  83. node->freq = kv.second;
  84. node->code = "";
  85. node->left = nullptr;
  86. node->right = nullptr;
  87. node->parent = nullptr;
  88. pq->push(node);
  89. }
  90. // insert pseudo-EOF
  91. auto node = new (std::nothrow) HuffmanNode();
  92. if (node == nullptr) {
  93. MS_LOG(ERROR) << "new HuffmanNode failed.";
  94. return RET_MEMORY_FAILED;
  95. }
  96. this->huffman_nodes_.push_back(node);
  97. node->key = PSEUDO_EOF;
  98. node->freq = 1;
  99. node->code = "";
  100. node->left = nullptr;
  101. node->right = nullptr;
  102. node->parent = nullptr;
  103. pq->push(node);
  104. return RET_OK;
  105. }
  106. void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
  107. if (is_left_node) {
  108. node->code = node->parent->code + "0";
  109. } else {
  110. node->code = node->parent->code + "1";
  111. }
  112. if (node->left == nullptr && node->right == nullptr) {
  113. huffman_table_[node->key] = node->code;
  114. } else {
  115. if (node->left != nullptr) {
  116. GenerateHuffmanTable(node->left, true);
  117. }
  118. if (node->right != nullptr) {
  119. GenerateHuffmanTable(node->right, false);
  120. }
  121. }
  122. }
  123. STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
  124. HuffmanNodePtr root = nullptr;
  125. while (!pq->empty()) {
  126. HuffmanNodePtr first = pq->top();
  127. pq->pop();
  128. if (pq->empty()) {
  129. root = first;
  130. break;
  131. }
  132. HuffmanNodePtr second = pq->top();
  133. pq->pop();
  134. auto new_node = new (std::nothrow) HuffmanNode();
  135. if (new_node == nullptr) {
  136. MS_LOG(ERROR) << "new HuffmanNode failed.";
  137. return RET_MEMORY_FAILED;
  138. }
  139. this->huffman_nodes_.push_back(new_node);
  140. new_node->freq = first->freq + second->freq;
  141. new_node->left = first;
  142. new_node->right = second;
  143. first->parent = new_node;
  144. second->parent = new_node;
  145. pq->push(new_node);
  146. }
  147. if (root == nullptr) {
  148. MS_LOG(ERROR) << "huffman tree root node is nullptr.";
  149. return RET_ERROR;
  150. }
  151. if (root->left != nullptr) {
  152. GenerateHuffmanTable(root->left, true);
  153. }
  154. if (root->right != nullptr) GenerateHuffmanTable(root->right, false);
  155. return RET_OK;
  156. }
  157. STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
  158. unsigned char out_c;
  159. string code_str;
  160. std::map<int, string>::iterator iter;
  161. std::vector<std::string> encode_str = {"", "", ""};
  162. huffman_encoded_str_.clear();
  163. for (iter = huffman_table_.begin(); iter != huffman_table_.end(); ++iter) {
  164. encode_str[0] += std::to_string(iter->first) + " ";
  165. encode_str[1] += iter->second + " ";
  166. }
  167. for (size_t i = 0; i < data_size; i++) {
  168. auto raw_num = input_datas[i];
  169. iter = huffman_table_.find(raw_num);
  170. if (iter != huffman_table_.end()) {
  171. code_str += iter->second;
  172. } else {
  173. MS_LOG(ERROR) << "Can't find the huffman code " << raw_num;
  174. return RET_ERROR;
  175. }
  176. }
  177. iter = huffman_table_.find(PSEUDO_EOF);
  178. if (iter != huffman_table_.end()) {
  179. code_str += iter->second;
  180. } else {
  181. MS_LOG(ERROR) << "Can't find the huffman code pseudo-EOF";
  182. return RET_ERROR;
  183. }
  184. out_c = 0;
  185. for (size_t i = 0; i < code_str.length(); i++) {
  186. auto tmp_c = code_str[i] == '0' ? 0 : 1;
  187. out_c += tmp_c << (7 - (i % 8));
  188. if (0 == (i + 1) % 8 || i == code_str.length() - 1) {
  189. encode_str[2] += out_c;
  190. out_c = 0;
  191. }
  192. }
  193. huffman_encoded_str_ = encode_str[0] + "#" + encode_str[1] + "#" + encode_str[2];
  194. return RET_OK;
  195. }
  196. HuffmanEncode::~HuffmanEncode() {
  197. for (auto &node : this->huffman_nodes_) {
  198. delete node;
  199. }
  200. this->huffman_nodes_.resize(0);
  201. }
  202. } // namespace lite
  203. } // namespace mindspore