You cannot select more than 25 topics. A topic must start with a Chinese character, a letter, or a number; it may include dashes ('-') and can be up to 35 characters long.

quantize_util.h 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
  17. #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
#include <algorithm>
#include <array>
#include <cfloat>
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "tools/converter/quantizer/quantizer.h"
#include "src/ops/primitive_c.h"
#include "include/errorcode.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "include/model.h"
#include "base/base.h"
#include "ir/primitive.h"
#include "abstract/dshape.h"
  35. namespace mindspore {
  36. namespace lite {
  37. namespace quant {
// Default quantization bit width (8-bit / uint8 range).
static constexpr size_t UINT8_QUANTIZATION = 8;
// Input index of the weight tensor on conv/matmul-like nodes.
static constexpr size_t WEIGHT_INDEX = 1;
/**
 * Decides which ops are eligible for weight/post-training quantization.
 * 1. when op's weight size > mWeightSize just skip
 * 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization
 * 3. when conv/deconv/convdepthwise/deconvdepthwise ops' weight channel size > covWeightQuantChannelThreshold just skip
 * */
class QuantStrategy {
 public:
  explicit QuantStrategy(size_t weightSize, size_t covWeightQuantChannelThreshold = 16);
  ~QuantStrategy() = default;
  // True if the conv-like node's weight passes the size/channel thresholds.
  bool CanConvOpQuantized(const CNodePtr &node) const;
  // True if the mul/matmul-like node's weight passes the size threshold.
  bool CanMulOpQuantized(const CNodePtr &node) const;
  // True if the node is a supported type for post-training quantization.
  bool CanOpPostQuantized(AnfNodePtr &node) const;

 private:
  size_t mWeightSize;                        // max weight element count to quantize
  size_t mConvWeightQuantChannelThreshold;   // max conv channel count to quantize
  static const std::vector<schema::PrimitiveType> conv_types;
  static const std::vector<schema::PrimitiveType> mul_types;
};
// Tuning constants used by the quantizer implementation (see the .cc file).
constexpr float delta = 0.1;
constexpr float ratio = 10.0;
constexpr int percent = 10;
// Computes scale/zero-point for [mMin, mMax] into quantParam, clamped to the
// explicit [quant_min, quant_max] integer range.
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
                             int quant_min, int num_bits);
// Convenience overload: derives the integer range from numBits (default 8-bit).
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange = false,
                             int numBits = UINT8_QUANTIZATION);
// Returns a (min, max) pair with outliers trimmed from the candidate ranges.
std::pair<float, float> OutlierMethod(std::vector<float> min_datas, std::vector<float> max_datas);
// K-means clustering of `data` into k clusters; cluster centers are recorded
// in quantParam and per-element cluster indices are returned.
std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam);
  67. template <typename T>
  68. T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
  69. MS_ASSERT(quantParam != nullptr);
  70. MS_ASSERT(quantParam->inited);
  71. const auto scale = quantParam->scale;
  72. const auto zeroPoint = quantParam->zeroPoint;
  73. const auto numBit = quantParam->numBits;
  74. const auto narrowRange = quantParam->narrowRange;
  75. double maxLimitTemp = static_cast<float>((1 << (unsigned int)numBit) - 1);
  76. const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<T>::min()) * scale;
  77. double minLimit;
  78. if (narrowRange) {
  79. minLimit = static_cast<float>(std::numeric_limits<T>::min() + 1 - zeroPoint) * scale;
  80. } else {
  81. minLimit = static_cast<float>(std::numeric_limits<T>::min() - zeroPoint) * scale;
  82. }
  83. return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
  84. double tmp = 0.0f;
  85. if (originData > maxLimit) {
  86. tmp = maxLimit;
  87. } else if (originData < minLimit) {
  88. tmp = minLimit;
  89. } else {
  90. tmp = originData;
  91. }
  92. auto quantData = static_cast<T>(std::round(zeroPoint + tmp / scale));
  93. return quantData;
  94. }();
  95. }
  96. template <typename T>
  97. T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quant_max, int quant_min) {
  98. MS_ASSERT(quantParam != nullptr);
  99. MS_ASSERT(quantParam->inited);
  100. const auto scale = quantParam.scale;
  101. const int zeroPoint = quantParam.zeroPoint;
  102. const auto narrowRange = quantParam.narrowRange;
  103. const int maxLimit = quant_max;
  104. const int minLimit = quant_min;
  105. return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
  106. auto quant_data = std::round(originData / scale + zeroPoint);
  107. if (quant_data > maxLimit) {
  108. quant_data = maxLimit;
  109. } else if (quant_data < minLimit) {
  110. quant_data = minLimit;
  111. }
  112. return static_cast<T>(quant_data);
  113. }();
  114. }
  115. template <typename T>
  116. STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
  117. int quant_max, int quant_min, size_t bitNum, bool per_channel) {
  118. auto dims = weight->tensor_shape();
  119. if (per_channel) {
  120. if (dims.size() != 4 && dims.size() != 2) {
  121. MS_LOG(INFO) << "weight dims size: " << dims.size() << " switch to per-layer quant mode.";
  122. per_channel = false;
  123. } else {
  124. auto op_type = (schema::PrimitiveType)primitive_c->Type();
  125. if (dims.size() == 2 && op_type != schema::PrimitiveType_FullConnection) {
  126. MS_LOG(INFO) << "weight dims size is 2 but op_type is not FullConnection, switch to per-layer quant mode.";
  127. per_channel = false;
  128. }
  129. uint32_t channels = dims[0];
  130. if (channels == 0) {
  131. MS_LOG(ERROR) << "channels is 0";
  132. return RET_ERROR;
  133. }
  134. }
  135. }
  136. std::vector<schema::QuantParamT> quant_params;
  137. size_t elem_count = weight->tensor_shape_size();
  138. auto *raw_datas = static_cast<float *>(weight->tensor_addr());
  139. if (raw_datas == nullptr) {
  140. MS_LOG(ERROR) << "rawDatas is nullptr";
  141. return RET_ERROR;
  142. }
  143. std::vector<T> quant_datas(elem_count);
  144. std::vector<float> dequant_datas(elem_count);
  145. if (per_channel) {
  146. // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC
  147. // channel at first
  148. auto channels = dims[0];
  149. if (channels == 0) {
  150. MS_LOG(ERROR) << "channels is zero";
  151. return RET_ERROR;
  152. }
  153. size_t one_filter_size = elem_count / channels;
  154. for (int i = 0; i < channels; i++) {
  155. float min = FLT_MAX;
  156. float max = -FLT_MAX;
  157. // find min and max
  158. for (size_t j = 0; j < one_filter_size; j++) {
  159. auto index = j + i * one_filter_size;
  160. if (index >= elem_count) {
  161. MS_LOG(ERROR) << "over flow!";
  162. return RET_ERROR;
  163. }
  164. min = std::min(min, raw_datas[index]);
  165. max = std::max(max, raw_datas[index]);
  166. }
  167. schema::QuantParamT quant_param;
  168. STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
  169. if (status != RET_OK) {
  170. MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
  171. return status;
  172. }
  173. // do quantization
  174. double average_dequant = 0;
  175. double average_raw = 0;
  176. for (uint32_t j = 0; j < one_filter_size; j++) {
  177. auto index = j + i * one_filter_size;
  178. if (index >= elem_count) {
  179. MS_LOG(ERROR) << "over flow!";
  180. return RET_ERROR;
  181. }
  182. float raw_data = raw_datas[index];
  183. auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
  184. quant_datas[index] = quant_data;
  185. if (quantType == QuantType_WeightQuant) {
  186. float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint);
  187. dequant_datas[index] = dequant_data;
  188. average_dequant += dequant_data;
  189. average_raw += raw_data;
  190. }
  191. }
  192. if (quantType == QuantType_WeightQuant && quant_param.clusters.size() == 0) {
  193. // mean
  194. average_dequant = average_dequant / one_filter_size;
  195. average_raw = average_raw / one_filter_size;
  196. // std
  197. double variance_dequant = 0;
  198. double variance_raw = 0;
  199. for (uint32_t j = 0; j < one_filter_size; j++) {
  200. auto index = j + i * one_filter_size;
  201. if (index >= elem_count) {
  202. MS_LOG(ERROR) << "over flow!";
  203. return RET_ERROR;
  204. }
  205. variance_dequant += std::pow(dequant_datas[index] - average_dequant, 2);
  206. variance_raw += std::pow(raw_datas[index] - average_raw, 2);
  207. }
  208. variance_dequant = std::sqrt(variance_dequant / one_filter_size);
  209. variance_raw = std::sqrt(variance_raw / one_filter_size);
  210. quant_param.varCorr = 1;
  211. if (variance_raw != 0 && variance_dequant != 0) {
  212. auto temp_var_corr = variance_raw / variance_dequant;
  213. if (temp_var_corr > 0 && temp_var_corr < 10) {
  214. quant_param.varCorr = temp_var_corr;
  215. } else {
  216. MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
  217. }
  218. }
  219. quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr;
  220. }
  221. quant_params.emplace_back(quant_param);
  222. }
  223. auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
  224. if (ret != EOK) {
  225. MS_LOG(ERROR) << "memcpy error: " << ret;
  226. return RET_ERROR;
  227. }
  228. weight->set_tensor_size(elem_count * sizeof(T));
  229. } else {
  230. // per layer
  231. float min = FLT_MAX;
  232. float max = -FLT_MIN;
  233. for (uint32_t i = 0; i < elem_count; i++) {
  234. // find max min
  235. min = std::min(min, raw_datas[i]);
  236. max = std::max(max, raw_datas[i]);
  237. }
  238. schema::QuantParamT quant_param;
  239. if (quant_param.clusters.size() == 0) {
  240. STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
  241. if (status != RET_OK) {
  242. MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
  243. return status;
  244. }
  245. }
  246. quant_params.emplace_back(quant_param);
  247. // update data and datatype
  248. for (uint32_t i = 0; i < elem_count; i++) {
  249. float raw_data = raw_datas[i];
  250. if (quant_param.clusters.size() == 0) {
  251. auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
  252. quant_datas[i] = quant_data;
  253. }
  254. }
  255. auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
  256. if (ret != EOK) {
  257. MS_LOG(ERROR) << "memcpy error: " << ret;
  258. return RET_ERROR;
  259. }
  260. weight->set_tensor_size(elem_count * sizeof(T));
  261. }
  262. if (quant_params.empty()) {
  263. MS_LOG(ERROR) << "quant_params empty";
  264. return RET_ERROR;
  265. }
  266. primitive_c->SetInputQuantParam(WEIGHT_INDEX, quant_params);
  267. return RET_OK;
  268. }
// Packs quantized weights into a dense bit stream of bitNum-bit values.
STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);
// Returns the schema primitive type of the given cnode.
schema::PrimitiveType NodePrimitiveType(CNodePtr cnode);
  271. } // namespace quant
  272. } // namespace lite
  273. } // namespace mindspore
  274. #endif